In [1]:
import sys
from pathlib import Path

# figure out the correct path
machop_path = Path(".").resolve().parent.parent.parent /"machop"
assert machop_path.exists(), "Failed to find machop at: {}".format(machop_path)
sys.path.append(str(machop_path))

from chop.dataset import MaseDataModule, get_dataset_info
from chop.tools.logger import set_logging_verbosity

from chop.passes.graph import (
    add_common_metadata_analysis_pass,
    init_metadata_analysis_pass,
    add_software_metadata_analysis_pass,
)
from chop.tools.get_input import InputGenerator
from chop.ir.graph.mase_graph import MaseGraph

from chop.models import get_model_info, get_model

set_logging_verbosity("info")


2024-03-25 14:18:08.639842: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
[32mINFO    [0m [34mSet logging level to info[0m


In [2]:
batch_size = 256
model_name = "jsc-custom"
dataset_name = "jsc"

data_module = MaseDataModule(
    name=dataset_name,
    batch_size=batch_size,
    model_name=model_name,
    num_workers=16,
)

data_module.prepare_data()
data_module.setup()

model_info = get_model_info(model_name)
dataset_info = get_dataset_info(dataset_name)

input_generator = InputGenerator(
    data_module=data_module,
    model_info=model_info,
    task="cls",
    which_dataloader="train",
)


In [3]:
model_base = get_model(
    model_name,
    task="cls",
    dataset_info=data_module.dataset_info,
    pretrained=False
)

model_oneshot = get_model(
    model_name,
    task="cls",
    dataset_info=data_module.dataset_info,
    pretrained=False
)

model_lt = get_model(
    model_name,
    task="cls",
    dataset_info=data_module.dataset_info,
    pretrained=False
)

dummy_in = {"x": next(iter(data_module.train_dataloader()))[0]}
# generate the mase graph and initialize node metadata
mg_base = MaseGraph(model=model_base)
mg_oneshot = MaseGraph(model=model_oneshot)
mg_lt = MaseGraph(model=model_lt)

mg_base, _ = init_metadata_analysis_pass(mg_base, None)
mg_base, _ = add_common_metadata_analysis_pass(mg_base, {"dummy_in": dummy_in})
mg_base, _ = add_software_metadata_analysis_pass(mg_base, None)

mg_oneshot, _ = init_metadata_analysis_pass(mg_oneshot, None)
mg_oneshot, _ = add_common_metadata_analysis_pass(mg_oneshot, {"dummy_in": dummy_in})
mg_oneshot, _ = add_software_metadata_analysis_pass(mg_oneshot, None)

mg_lt, _ = init_metadata_analysis_pass(mg_lt, None)
mg_lt, _ = add_common_metadata_analysis_pass(mg_lt, {"dummy_in": dummy_in})
mg_lt, _ = add_software_metadata_analysis_pass(mg_lt, None)

Training the VGG network and then pruning it to see final performance

In [4]:
import copy


default_train_params = {
    "model_info": model_info,
    "data_module": data_module,
    "dataset_info": dataset_info,
    "task": "cls",
    "optimizer": "adam",
    "learning_rate": 1e-3,
    "weight_decay": 0,
    "plt_trainer_args": {
        "max_epochs": 5,
    }, 
    "auto_requeue": False,
    "save_path": None,
    "visualizer": None,
    "load_name": None,
    "load_type": None
}

base_train_params = copy.deepcopy(default_train_params)
base_train_params["model"] = mg_base.model
base_train_params["plt_trainer_args"]["max_epochs"] = 5

oneshot_train_params_1 = copy.deepcopy(default_train_params)
oneshot_train_params_1["model"] = mg_oneshot.model
oneshot_train_params_1["plt_trainer_args"]["max_epochs"] = 1

oneshot_train_params_2 = copy.deepcopy(default_train_params)
oneshot_train_params_2["model"] = mg_oneshot.model
oneshot_train_params_2["plt_trainer_args"]["max_epochs"] = 1

lt_train_params = copy.deepcopy(default_train_params)
lt_train_params["model"] = mg_lt.model
lt_train_params["plt_trainer_args"]["max_epochs"] = 1

Base model:

In [None]:
from chop.actions import test, train

train(**base_train_params)
test(**base_train_params)

Base Model after pruning:

In [None]:
from chop.passes.graph.transforms import (
    prune_transform_pass,
)

pass_args = {
    "weight":{
        "scope" : "local",
        "granularity" : "elementwise",
        "method" :  "l1-norm",
        "sparsity" : 0.7,},
    "activation":{
        "scope" : "local",
        "granularity" : "elementwise",
        "method" : "l1-norm",
        "sparsity" : 0.7,
    },
}
 
mg_base, _ = prune_transform_pass(mg_base, pass_args)

base_train_params["model"] = mg_base.model
test(**base_train_params)

One shot pruning:

In [5]:
from chop.actions import test, train
from chop.passes.graph.transforms import (
    prune_transform_pass,
)

# train(**oneshot_train_params_1)
pass_args = {
    "weight":{
        "scope" : "local",
        "granularity" : "elementwise",
        "method" :  "l1-norm",
        "sparsity" : 0.7,},
    "activation":{
        "scope" : "local",
        "granularity" : "elementwise",
        "method" : "l1-norm",
        "sparsity" : 0.7,
    },
}


In [6]:
mg_oneshot.model = oneshot_train_params_1["model"]
mg_oneshot, _ = prune_transform_pass(mg_oneshot, pass_args)

# oneshot_train_params_2["model"] = mg_oneshot.model
# train(**oneshot_train_params_2)
# test(**oneshot_train_params_2)

In [7]:
from chop.passes.graph.analysis.pruning.calculate_sparsity import add_pruning_metadata_analysis_pass


mg_oneshot, _ = add_pruning_metadata_analysis_pass(mg_oneshot, {"dummy_in": dummy_in, "add_value": True})



In [17]:
from pprint import pprint
from chop.passes.graph.utils import get_mase_op


for node in mg_oneshot.fx_graph.nodes:
    if node.name == "seq_blocks_10" and get_mase_op(node) in ['linear', 'conv2d']:
        print(node.name)
        pprint(node.meta['mase'].parameters['common'])
        pprint(node.meta['mase'].parameters['software'])

seq_blocks_10
{'args': {'bias': {'from': None,
                   'precision': [32],
                   'shape': [5],
                   'type': 'float',
                   'value': Parameter containing:
tensor([-0.1997,  0.1969,  0.0932, -0.1758, -0.1422], requires_grad=True)},
          'data_in_0': {'precision': [32],
                        'shape': [256, 16],
                        'torch_dtype': torch.float32,
                        'type': 'float',
                        'value': tensor([[0.0000, 0.0000, 0.0000,  ..., 1.3371, 3.6396, 0.0000],
        [0.6497, 0.0000, 0.5663,  ..., 0.0400, 0.0000, 0.0000],
        [0.5050, 0.1814, 0.0000,  ..., 0.5731, 0.0000, 0.0000],
        ...,
        [0.6828, 0.4171, 0.1024,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 1.8731, 0.0000],
        [0.1452, 0.0000, 0.0000,  ..., 1.9893, 2.3441, 2.0208]],
       grad_fn=<MulBackward0>)},
          'weight': {'from': None,
                     'precision': [32],

In [None]:
from chop.passes.graph.utils import get_mase_op


for node in mg_oneshot.fx_graph.nodes:
    if get_mase_op(node) in ['linear', "conv2d", "conv1d"]:
        print(node.name)
        total_w = 0
        pruned_w = 0
        module = mg_oneshot.modules[node.target]
        print(module.weight.numel())
        w = module.weight
        print("Num of masks: ", len(mg_oneshot.modules[node.target].parametrizations['weight']))

        for s in w:
            total_w += s.numel()
            pruned_w += s.numel() - s[s.nonzero()].numel()

        pruned_percent = pruned_w / total_w
        print(f"Pruned percent: {pruned_percent}")
        print(75*'-')

Lottery Ticket:

In [22]:
from chop.passes.graph.utils import get_mase_op
import copy, torch
from chop.actions import test, train
from chop.passes.graph.transforms import (
    prune_transform_pass,
)


def prune_lottery_ticket(mg, pass_args):
    overall_sparsity = pass_args["lottery_ticket"]["sparsity"]
    num_iterations = pass_args["lottery_ticket"]["num_iterations"]

    iteration_sparsity = 1 - ((1 - overall_sparsity)**(1 / num_iterations))

    data_module = MaseDataModule(
        name=pass_args["config"]["dataset"],
        batch_size=pass_args["config"]["batch_size"],
        model_name=pass_args["config"]["model"],
        num_workers=0,
    )
    data_module.prepare_data()
    data_module.setup()

    train_test_args = {
        "model": mg.model,
        "model_info": get_model_info(pass_args["config"]["model"]),
        "data_module": data_module,
        "dataset_info": get_dataset_info(pass_args["config"]["dataset"]),
        "task": pass_args["config"]["task"],
        "optimizer": pass_args["config"]["optimizer"],
        "learning_rate": pass_args["config"]["learning_rate"],
        "weight_decay": pass_args["config"]["weight_decay"],
        "plt_trainer_args": {
            "max_epochs": pass_args["config"]["max_epochs"],
        },
        "auto_requeue": False,
        "save_path": None,
        "visualizer": None,
        "load_name": None,
        "load_type": None,
    }

    prune_args = {
        "weight": {
            "scope": pass_args["lottery_ticket"]["scope"],
            "granularity": pass_args["lottery_ticket"]["granularity"],
            "method": pass_args["lottery_ticket"]["method"],
            "sparsity": iteration_sparsity,
        },
        "activation": {
            "scope": pass_args["lottery_ticket"]["scope"],
            "granularity": pass_args["lottery_ticket"]["granularity"],
            "method": pass_args["lottery_ticket"]["method"],
            "sparsity": iteration_sparsity,
        },
    }


    original_w_b = {}

    for node in mg.fx_graph.nodes:
        if get_mase_op(node) in ["linear", "conv2d", "conv1d"]:
            original_w_b[node.name] = {
                "weight": mg.modules[node.target].weight,
                "bias": mg.modules[node.target].bias,
                "meta_weight": node.meta["mase"].parameters["common"]["args"]["weight"]["value"],
                "meta_bias": node.meta["mase"].parameters["common"]["args"]["bias"]["value"],
            }
            
    for i in range(num_iterations):
        mg, _ = prune_transform_pass(mg, prune_args)

        train(**train_test_args)

        # copy the weights from the original model to the pruned model
        for node in mg.fx_graph.nodes:
            if get_mase_op(node) in ["linear", "conv2d", "conv1d"]:
                with torch.no_grad():
                    mg.modules[node.target].weight.copy_(original_w_b[node.name]['weight'])
                    # mg.modules[node.target].weight.copy_(original_w_b[node.name]['weight'])

                    mg.modules[node.target].bias.copy_(original_w_b[node.name]['bias'])

                    # update the mase metadata weights
                    node.meta["mase"].parameters["common"]["args"]["weight"]["value"] = original_w_b[node.name]['meta_weight']
                    
                    node.meta["mase"].parameters["common"]["args"]["bias"]["value"] = original_w_b[node.name]['meta_bias']

    train(**train_test_args)

    test(**train_test_args)

    return mg, {}

In [20]:
data_module = MaseDataModule(
    name=dataset_name,
    batch_size=batch_size,
    model_name=model_name,
    num_workers=16,
)

data_module.prepare_data()
data_module.setup()

model_info = get_model_info(model_name)
dataset_info = get_dataset_info(dataset_name)

input_generator = InputGenerator(
    data_module=data_module,
    model_info=model_info,
    task="cls",
    which_dataloader="train",
)

model_lt = get_model(
    model_name,
    task="cls",
    dataset_info=data_module.dataset_info,
    pretrained=False
)

dummy_in = {"x": next(iter(data_module.train_dataloader()))[0]}

mg_lt = MaseGraph(model=model_lt)

mg_lt, _ = init_metadata_analysis_pass(mg_lt, None)
mg_lt, _ = add_common_metadata_analysis_pass(mg_lt, {"dummy_in": dummy_in})
mg_lt, _ = add_software_metadata_analysis_pass(mg_lt, None)

lt_train_params['data_module'] = data_module
lt_train_params['model'] = mg_lt.model

In [21]:
pass_args = {
    "lottery_ticket": {
        "num_iterations": 5,
        "scope": "global",
        "granularity": "elementwise",
        "method": "l1-norm",
        "sparsity": 0.7
    },
    "config": {
        "dataset": "jsc",
        "model": "jsc-custom",
        "task": "cls",
        "batch_size": 256,
        "max_epochs": 1,
        "optimizer": "adam",
        "learning_rate": 1e-3,
        "weight_decay": 0,
    }
}

mg_lt, _ = prune(mg_lt, pass_args)

lt_train_params["model"] = mg_lt.model
test(**lt_train_params)

NameError: name 'prune_lottery_ticket' is not defined

In [25]:
last = []
for node in mg_lt.fx_graph.nodes:
    if get_mase_op(node) in ['linear', "conv2d", "conv1d"]:
        print(node.name)
        total_w = 0
        pruned_w = 0
        module = mg_lt.modules[node.target]
        print(module.weight.numel())
        w = module.weight
        print("Num of masks: ", len(mg_lt.modules[node.target].parametrizations['weight']))
        if node.name == "seq_blocks_10":
            for mask in mg_lt.modules[node.target].parametrizations['weight']:
                last.append(mask._buffers['mask'].detach().numpy())

        for s in w:
            total_w += s.numel()
            pruned_w += s.numel() - s[s.nonzero()].numel()

        pruned_percent = pruned_w / total_w
        print(f"Pruned percent: {pruned_percent}")
        print(75*'-')

seq_blocks_2
512
Num of masks:  5
Pruned percent: 0.6875
---------------------------------------------------------------------------
seq_blocks_6
512
Num of masks:  5
Pruned percent: 0.73046875
---------------------------------------------------------------------------
seq_blocks_10
80
Num of masks:  5
Pruned percent: 0.6
---------------------------------------------------------------------------


In [27]:
for m in last:
    m = m.flatten()
    print(len(m[m==False]) / len(m))
    print(75*'-')

0.2
---------------------------------------------------------------------------
0.3
---------------------------------------------------------------------------
0.3875
---------------------------------------------------------------------------
0.4875
---------------------------------------------------------------------------
0.6
---------------------------------------------------------------------------
