In [62]:
import sys
import logging
import os
from pathlib import Path
from pprint import pprint as pp

# Locate the `machop` directory two levels above the current working directory
# and put it on the import path so the `chop` imports below can be resolved.
machop_path = Path(".").resolve().parent.parent / "machop"
assert machop_path.exists(), f"Failed to find machop at: {machop_path}"
# Guard the append so re-running this cell does not stack duplicate entries
# onto sys.path.
if str(machop_path) not in sys.path:
    sys.path.append(str(machop_path))

from chop.tools.checkpoint_load import load_model
from chop.dataset import MaseDataModule, get_dataset_info
from chop.tools.logger import set_logging_verbosity

from chop.passes.graph.analysis import (
    report_node_meta_param_analysis_pass,
    profile_statistics_analysis_pass,
)
from chop.passes.graph import (
    add_common_metadata_analysis_pass,
    init_metadata_analysis_pass,
    add_software_metadata_analysis_pass,
)
from chop.tools.get_input import InputGenerator
from chop.ir.graph.mase_graph import MaseGraph

from chop.models import get_model_info, get_model

# Set chop's logging verbosity to "info"; the call itself logs a confirmation line.
set_logging_verbosity("info")


[32mINFO    [0m [34mSet logging level to info[0m


In [63]:
# --- Configuration for the first experiment: "jsc-tiny" on the "jsc" dataset ---
batch_size = 512
model_name = "jsc-tiny"
dataset_name = "jsc"

# Build the data module, then run its prepare and setup steps.
data_module = MaseDataModule(
    name=dataset_name, batch_size=batch_size, model_name=model_name, num_workers=0
)
data_module.prepare_data()
data_module.setup()

# Create the model from scratch (no pretrained weights, no checkpoint).
# Optional: weights from an earlier run can be loaded afterwards with
# `load_model(load_name=<ckpt path>, load_type="pl", model=model)`.
model_info = get_model_info(model_name)
model = get_model(
    model_name,
    task="cls",
    dataset_info=data_module.dataset_info,
    pretrained=False,
    checkpoint=None,
)

# Pull a single batch from the training dataloader to serve as a dummy input,
# and run it through the model once (the result is discarded).
input_generator = InputGenerator(
    data_module=data_module,
    model_info=model_info,
    task="cls",
    which_dataloader="train",
    max_batches=1,
)
dummy_in = next(iter(input_generator))
_ = model(**dummy_in)

# Wrap the model in a MaseGraph; node metadata is populated by the analysis
# passes that are run on it later.
mg = MaseGraph(model=model)

In [64]:
from chop.actions import train
import torch


## Before Pruning

In [65]:
from pprint import pprint

from chop.passes.graph.utils import get_mase_op

# Run the three metadata analysis passes on the graph; the common pass is
# given the dummy input batch.
mg, _ = init_metadata_analysis_pass(mg, None)
mg, _ = add_common_metadata_analysis_pass(mg, {"dummy_in": dummy_in})
mg, _ = add_software_metadata_analysis_pass(mg, None)

# Show the weight and the full state dict of every linear layer that is
# implemented as a module call.
for node in mg.fx_graph.nodes:
    if not (get_mase_op(node) == "linear" and node.op == "call_module"):
        continue
    pprint(mg.modules[node.target].weight)
    print("-" * 50)
    pprint(mg.modules[node.target].state_dict())
            
            
#print(mg.modules)
#print(type(mg.model))

Parameter containing:
tensor([[-0.0071, -0.1201, -0.2267, -0.0036,  0.0387,  0.2377, -0.0907,  0.1734,
         -0.0068, -0.1752,  0.1893,  0.1971, -0.2455, -0.0030,  0.0605,  0.0689],
        [ 0.0777, -0.0160, -0.1336,  0.1596, -0.1498,  0.0826, -0.0080, -0.1852,
         -0.1340, -0.0444,  0.2480, -0.1136, -0.0744, -0.0110, -0.0194,  0.0241],
        [-0.1639,  0.2252, -0.0322, -0.0671,  0.0073, -0.0872, -0.1262,  0.2293,
         -0.2015, -0.1364,  0.1217, -0.1235,  0.0570,  0.0812,  0.2477,  0.1620],
        [ 0.0175,  0.0241,  0.2209, -0.1608, -0.2049,  0.0016, -0.2360,  0.2398,
          0.0985, -0.0579, -0.2063, -0.1056, -0.2181, -0.2238,  0.2105,  0.0833],
        [-0.0650, -0.1742, -0.1779, -0.2297, -0.0787, -0.1829, -0.1394,  0.0014,
         -0.0916,  0.1522, -0.1246, -0.0163,  0.2452,  0.0194, -0.0331, -0.2496]],
       requires_grad=True)
--------------------------------------------------
OrderedDict([('weight',
              tensor([[-0.0071, -0.1201, -0.2267, -0.0036,  

In [66]:
from chop.passes.graph.transforms import (
    prune_transform_pass,
    activation_pruning_pass,
)
# Pruning configuration: weights and activations use the same settings
# (scope "local", granularity "elementwise", method "l1-norm", sparsity 0.5).
pass_args = dict(
    weight=dict(
        scope="local",
        granularity="elementwise",
        method="l1-norm",
        sparsity=0.5,
    ),
    activation=dict(
        scope="local",
        granularity="elementwise",
        method="l1-norm",
        sparsity=0.5,
    ),
)

# Apply the pruning transform; the second element of the returned pair is discarded.
mg, _ = prune_transform_pass(mg, pass_args)



## After Pruning

In [67]:

# Same inspection as before pruning: print the weight and the state dict of
# every linear layer that is implemented as a module call.
for node in mg.fx_graph.nodes:
    if not (get_mase_op(node) == "linear" and node.op == "call_module"):
        continue
    pprint(mg.modules[node.target].weight)
    print("-" * 50)
    pprint(mg.modules[node.target].state_dict())

tensor([[-0.0000, -0.0000, -0.2267, -0.0000,  0.0000,  0.2377, -0.0000,  0.1734,
         -0.0000, -0.1752,  0.1893,  0.1971, -0.2455, -0.0000,  0.0000,  0.0000],
        [ 0.0000, -0.0000, -0.1336,  0.1596, -0.1498,  0.0000, -0.0000, -0.1852,
         -0.1340, -0.0000,  0.2480, -0.0000, -0.0000, -0.0000, -0.0000,  0.0000],
        [-0.1639,  0.2252, -0.0000, -0.0000,  0.0000, -0.0000, -0.1262,  0.2293,
         -0.2015, -0.1364,  0.0000, -0.1235,  0.0000,  0.0000,  0.2477,  0.1620],
        [ 0.0000,  0.0000,  0.2209, -0.1608, -0.2049,  0.0000, -0.2360,  0.2398,
          0.0000, -0.0000, -0.2063, -0.0000, -0.2181, -0.2238,  0.2105,  0.0000],
        [-0.0000, -0.1742, -0.1779, -0.2297, -0.0000, -0.1829, -0.1394,  0.0000,
         -0.0000,  0.1522, -0.1246, -0.0000,  0.2452,  0.0000, -0.0000, -0.2496]],
       grad_fn=<MulBackward0>)
--------------------------------------------------
OrderedDict([('bias', tensor([ 0.0751, -0.0778,  0.0786, -0.0864, -0.1024])),
             ('parametri

## Training

In [68]:
# Train the model held by the (pruned) MaseGraph for a single epoch.
model = mg.model
# Look the model/dataset info up through the names set in the configuration
# cell rather than repeating the string literals, so this cell cannot drift
# out of sync with the model and data module that were actually built.
model_info = get_model_info(model_name)
dataset_info = get_dataset_info(dataset_name)
task = "cls"

# Keyword arguments forwarded to chop's `train` action.
train_params = {
    "model": model,
    "model_info": model_info,
    "data_module": data_module,
    "dataset_info": dataset_info,
    "task": task,
    "optimizer": "adam",
    "learning_rate": 1e-3,
    "weight_decay": 0,
    "plt_trainer_args": {
        "max_epochs": 1,
    },
    "auto_requeue": False,
    "save_path": None,
    "visualizer": None,
    "load_name": None,
    "load_type": None,
}

train(**train_params)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type               | Params
-------------------------------------------------
0 | model     | GraphModule        | 117   
1 | loss_fn   | CrossEntropyLoss   | 0     
2 | acc_train | MulticlassAccuracy | 0     
3 | acc_val   | MulticlassAccuracy | 0     
4 | acc_test  | MulticlassAccuracy | 0     
5 | loss_val  | MeanMetric         | 0     
6 | loss_test | MeanMetric         | 0     
-------------------------------------------------
117       Trainable params
0         Non-trainable params
117       Total params
0.000     Total estimated model params size (MB)


                                                                           

/home/agomotto3000/anaconda3/envs/mase/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
/home/agomotto3000/anaconda3/envs/mase/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Epoch 0: 100%|██████████| 1542/1542 [00:31<00:00, 49.04it/s, v_num=30, train_acc_step=0.445, val_acc_epoch=0.456, val_loss_epoch=1.300]

`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 1542/1542 [00:31<00:00, 49.03it/s, v_num=30, train_acc_step=0.445, val_acc_epoch=0.456, val_loss_epoch=1.300]


## Channelwise

In [79]:
# --- Configuration for the second experiment: "toy_convnet" on "Cifar10" ---
batch_size = 512
model_name = "toy_convnet"
dataset_name = "Cifar10"

# Build the data module, then run its prepare and setup steps.
data_module = MaseDataModule(
    name=dataset_name, batch_size=batch_size, model_name=model_name, num_workers=0
)
data_module.prepare_data()
data_module.setup()

# Create the model from scratch (no pretrained weights, no checkpoint).
model_info = get_model_info(model_name)
model = get_model(
    model_name,
    task="cls",
    dataset_info=data_module.dataset_info,
    pretrained=False,
    checkpoint=None,
)

# Pull a single batch from the training dataloader to serve as a dummy input,
# and run it through the model once (the result is discarded).
input_generator = InputGenerator(
    data_module=data_module,
    model_info=model_info,
    task="cls",
    which_dataloader="train",
    max_batches=1,
)
dummy_in = next(iter(input_generator))
_ = model(**dummy_in)

# Wrap the new model in a fresh MaseGraph, replacing the previous `mg`.
mg = MaseGraph(model=model)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


## Before Pruning

In [80]:
from pprint import pprint

from chop.passes.graph.utils import get_mase_op

# Run the three metadata analysis passes on the graph; the common pass is
# given the dummy input batch.
mg, _ = init_metadata_analysis_pass(mg, None)
mg, _ = add_common_metadata_analysis_pass(mg, {"dummy_in": dummy_in})
mg, _ = add_software_metadata_analysis_pass(mg, None)

# Show the weight and the full state dict of every conv2d layer that is
# implemented as a module call.
for node in mg.fx_graph.nodes:
    if not (get_mase_op(node) == "conv2d" and node.op == "call_module"):
        continue
    pprint(mg.modules[node.target].weight)
    print("-" * 50)
    pprint(mg.modules[node.target].state_dict())

Parameter containing:
tensor([[[[ 1.3748e-01,  2.1479e-02, -9.4732e-02],
          [ 2.9365e-02,  1.4427e-01, -8.5327e-02],
          [ 1.6447e-02,  8.1002e-02, -2.6879e-02]],

         [[-1.3203e-01, -4.0379e-02,  1.2712e-01],
          [-1.8229e-01, -1.4738e-02, -5.4161e-02],
          [-2.4612e-02, -1.1651e-01, -3.0695e-02]],

         [[ 1.2624e-01, -1.0735e-02,  1.4186e-01],
          [ 9.0847e-02, -1.2447e-01,  1.7536e-01],
          [ 1.0018e-01,  8.5701e-02,  1.7813e-03]]],


        [[[ 6.9845e-02,  1.1027e-01,  1.1402e-01],
          [-1.0543e-01,  1.2811e-01, -1.7400e-01],
          [ 1.2780e-01,  1.2102e-01, -9.4646e-02]],

         [[ 4.9147e-02,  1.9159e-01, -2.3558e-02],
          [-1.8466e-01,  1.4989e-01, -1.3137e-01],
          [-6.7507e-02,  8.6385e-02,  5.5279e-03]],

         [[ 2.8780e-02, -1.3954e-01,  1.6852e-01],
          [ 3.2005e-02, -1.9144e-01,  9.4916e-02],
          [-2.4395e-03,  1.1808e-01, -1.2889e-01]]],


        [[[-2.2024e-05, -2.7090e-03, -3.6521

In [84]:
# Pruning configuration: weights are pruned at "channel" granularity with
# sparsity 0.8; activations stay "elementwise" with sparsity 0.5.
pass_args = dict(
    weight=dict(
        scope="local",
        granularity="channel",
        method="l1-norm",
        sparsity=0.8,
    ),
    activation=dict(
        scope="local",
        granularity="elementwise",
        method="l1-norm",
        sparsity=0.5,
    ),
)

# Apply the pruning transform; the second element of the returned pair is discarded.
mg, _ = prune_transform_pass(mg, pass_args)

4
4
3
3
2


In [85]:
# Print the weight and the state dict of every conv2d layer that is
# implemented as a module call, to compare against the pre-pruning values.
for node in mg.fx_graph.nodes:
    if not (get_mase_op(node) == "conv2d" and node.op == "call_module"):
        continue
    pprint(mg.modules[node.target].weight)
    print("-" * 50)
    pprint(mg.modules[node.target].state_dict())

tensor([[[[ 0.0000,  0.0000, -0.0000],
          [ 0.0000,  0.0000, -0.0000],
          [ 0.0000,  0.0000, -0.0000]],

         [[-0.0000, -0.0000,  0.0000],
          [-0.0000, -0.0000, -0.0000],
          [-0.0000, -0.0000, -0.0000]],

         [[ 0.0000, -0.0000,  0.0000],
          [ 0.0000, -0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000]]],


        [[[ 0.0698,  0.1103,  0.1140],
          [-0.1054,  0.1281, -0.1740],
          [ 0.1278,  0.1210, -0.0946]],

         [[ 0.0491,  0.1916, -0.0236],
          [-0.1847,  0.1499, -0.1314],
          [-0.0675,  0.0864,  0.0055]],

         [[ 0.0288, -0.1395,  0.1685],
          [ 0.0320, -0.1914,  0.0949],
          [-0.0024,  0.1181, -0.1289]]],


        [[[-0.0000, -0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000],
          [-0.0000, -0.0000, -0.0000]],

         [[ 0.0000,  0.0000, -0.0000],
          [-0.0000,  0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000]],

         [[-0.0000, -0.0000, -0.0000],
     

: 

In [71]:
from pprint import pprint

from chop.passes.graph.utils import get_mase_op

# pprint(mg.meta['mase'].__dict__)

# For every node that calls a module, print the items of that module's
# `_forward_pre_hooks` mapping.
for node in mg.fx_graph.nodes:
    if node.op != "call_module":
        continue
    pprint(mg.modules[node.target]._forward_pre_hooks.items())



odict_items([])
odict_items([])
odict_items([])
odict_items([])
odict_items([])
odict_items([])
odict_items([])
odict_items([])
odict_items([])
odict_items([])
odict_items([])


In [72]:
from pprint import pprint

from chop.passes.graph.utils import get_mase_op

# pprint(mg.meta['mase'].__dict__)

# For every linear node, print its name followed by the "args" entry of its
# common MASE metadata, then a separator line.
# NOTE(review): an earlier draft of this cell also estimated the pruned
# fraction from `mg.modules[node.target].parametrizations['weight'][0].mask`;
# that attribute path is taken from the removed draft and is not verified here.
for node in mg.fx_graph.nodes:
    if get_mase_op(node) != "linear":
        continue
    print(f"Layer: {node.name}")
    pprint(node.meta["mase"].parameters["common"]["args"])
    print("-" * 50)
#print(mg.model.state_dict())

Layer: linear
{'bias': {'from': None,
          'precision': [32],
          'shape': [10],
          'type': 'float',
          'value': Parameter containing:
tensor([-0.0445,  0.0786,  0.1041, -0.0927,  0.0147,  0.0457,  0.1057, -0.0290,
         0.0905,  0.0922], requires_grad=True)},
 'data_in_0': {'precision': [32],
               'shape': [512, 64],
               'torch_dtype': torch.float32,
               'type': 'float',
               'value': tensor([[4.4416e-02, 0.0000e+00, 1.4396e-01,  ..., 4.4614e-04, 5.0736e-02,
         0.0000e+00],
        [3.8386e-02, 0.0000e+00, 1.9087e-01,  ..., 6.7570e-03, 4.6465e-02,
         9.0773e-06],
        [2.2434e-02, 0.0000e+00, 1.6967e-01,  ..., 6.1829e-04, 3.0395e-02,
         0.0000e+00],
        ...,
        [1.4487e-02, 0.0000e+00, 2.2211e-01,  ..., 8.9831e-05, 9.8996e-03,
         0.0000e+00],
        [4.6827e-02, 0.0000e+00, 1.6343e-01,  ..., 3.0457e-03, 5.0189e-02,
         0.0000e+00],
        [2.3559e-02, 0.0000e+00, 1.7879e-01

In [73]:
from chop.passes.graph.transforms import (
    prune_transform_pass,
)
# Pruning configuration: weights use "global" scope, activations use "local"
# scope; both are "elementwise", "l1-norm", with sparsity 0.6.
pass_args = dict(
    weight=dict(
        scope="global",
        granularity="elementwise",
        method="l1-norm",
        sparsity=0.6,
    ),
    activation=dict(
        scope="local",
        granularity="elementwise",
        method="l1-norm",
        sparsity=0.6,
    ),
)

# NOTE(review): when the notebook is run top to bottom, `mg` has already been
# through prune_transform_pass in an earlier cell and is not rebuilt before
# this call — confirm that pruning the same graph twice is intended.
mg, _ = prune_transform_pass(mg, pass_args)


In [74]:
# Print the complete state dict of the graph's model so the stored tensors can
# be inspected after pruning. Every tensor is printed in full, so the output
# is long.
print(mg.model.state_dict())
#print(node.meta['mase'].parameters['common']['args']['weight']['value'])

OrderedDict([('block_1.0.bias', tensor([ 0.1008,  0.1006,  0.0213,  0.0648, -0.0629, -0.1817, -0.0902, -0.1284])), ('block_1.0.parametrizations.weight.original', tensor([[[[ 0.0944,  0.0471,  0.0161],
          [-0.1590,  0.0563,  0.1508],
          [ 0.1106,  0.1515, -0.0847]],

         [[ 0.0242, -0.0580, -0.0695],
          [-0.0390, -0.0946,  0.0599],
          [ 0.0066,  0.1182,  0.0722]],

         [[ 0.0477,  0.1016,  0.0493],
          [ 0.0980, -0.0535, -0.0730],
          [-0.0436, -0.0778,  0.1692]]],


        [[[-0.0118,  0.0479, -0.1685],
          [-0.1033,  0.1019,  0.1909],
          [-0.0941, -0.1594,  0.0562]],

         [[ 0.1403, -0.1430, -0.1551],
          [ 0.1700, -0.1056, -0.0765],
          [-0.1542, -0.1113, -0.1068]],

         [[-0.0073, -0.0518, -0.0062],
          [-0.1183, -0.0293,  0.0502],
          [ 0.0126,  0.0405,  0.0821]]],


        [[[ 0.1837,  0.0905, -0.0382],
          [ 0.1190, -0.1652, -0.0091],
          [ 0.1463,  0.0432,  0.0725]],

 

In [75]:
# from chop.actions import train
# import torch

# # print(isinstance(mg.model, torch.nn.Module))

# model = mg.model
# model_info = get_model_info('jsc-tiny')
# dataset_info = get_dataset_info('jsc')
# task = "cls"

# train_params = {
#     "model": model,
#     "model_info": model_info,
#     "data_module": data_module,
#     "dataset_info": dataset_info,
#     "task": task,
#     "optimizer": "adam",
#     "learning_rate": 1e-3,
#     "weight_decay": 0,
#     "plt_trainer_args": {
#         "max_epochs": 1,
#     }, 
#     "auto_requeue": False,
#     "save_path": None,
#     "visualizer": None,
#     "load_name": None,
#     "load_type": None
# }

# train(**train_params)

In [76]:
# from pprint import pprint

# from chop.passes.graph.utils import get_mase_op


# # pprint(mg.meta['mase'].__dict__)

# For every linear node, print its name and the weight value stored in the
# node's common MASE metadata, framed by separator lines.
# NOTE(review): this reads the metadata entry, not the live module weight;
# whether it reflects pruning should be confirmed.
for node in mg.fx_graph.nodes:
    if get_mase_op(node) != "linear":
        continue
    print(node.name)
    print("-" * 50)
    pprint(node.meta["mase"].parameters["common"]["args"]["weight"]["value"])
    print("-" * 50)

linear
--------------------------------------------------
Parameter containing:
tensor([[ 0.0934,  0.1204,  0.0324, -0.0053, -0.0041,  0.0687,  0.0620, -0.0787,
          0.1006, -0.0579,  0.0668,  0.1244, -0.0975,  0.0899, -0.0715,  0.0163,
          0.0927,  0.0796, -0.1143,  0.0537,  0.0397,  0.0691, -0.0862,  0.1157,
          0.0113, -0.1061,  0.0944,  0.0460, -0.0097, -0.0417,  0.0521,  0.0446,
         -0.1070,  0.1065,  0.0362,  0.0056,  0.0176, -0.0080,  0.0745, -0.0038,
         -0.0722,  0.1062, -0.1182, -0.0822,  0.0119,  0.0928, -0.0454, -0.0541,
          0.0584,  0.0731, -0.0997,  0.0477, -0.0369,  0.0539,  0.0151,  0.0952,
          0.0933, -0.0085,  0.0990,  0.0674,  0.0712,  0.0218, -0.1189,  0.0141],
        [-0.1125,  0.1211, -0.1075, -0.0887, -0.1094,  0.0586,  0.1129, -0.0207,
         -0.0451,  0.0991, -0.0577,  0.0336,  0.0432, -0.0425, -0.0087, -0.0105,
         -0.0011, -0.0081,  0.0166, -0.0172, -0.0564, -0.0077,  0.0975,  0.0683,
         -0.1064,  0.0903, -