In [1]:
import sys
from pathlib import Path

# figure out the correct path
machop_path = Path(".").resolve().parent.parent /"machop"
assert machop_path.exists(), "Failed to find machop at: {}".format(machop_path)
sys.path.append(str(machop_path))

from chop.dataset import MaseDataModule, get_dataset_info
from chop.tools.logger import set_logging_verbosity

from chop.passes.graph import (
    add_common_metadata_analysis_pass,
    init_metadata_analysis_pass,
    add_software_metadata_analysis_pass,
)
from chop.tools.get_input import InputGenerator
from chop.ir.graph.mase_graph import MaseGraph

from chop.models import get_model_info, get_model

set_logging_verbosity("info")


[32mINFO    [0m [34mSet logging level to info[0m


In [2]:
batch_size = 256
model_name = "vgg7"
dataset_name = "cifar10"

data_module = MaseDataModule(
    name=dataset_name,
    batch_size=batch_size,
    model_name=model_name,
    num_workers=0,
)

data_module.prepare_data()
data_module.setup()

model_info = get_model_info(model_name)
dataset_info = get_dataset_info(dataset_name)

input_generator = InputGenerator(
    data_module=data_module,
    model_info=model_info,
    task="cls",
    which_dataloader="train",
)


Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [3]:
model_base = get_model(
    model_name,
    task="cls",
    dataset_info=data_module.dataset_info,
    pretrained=False
)

model_oneshot = get_model(
    model_name,
    task="cls",
    dataset_info=data_module.dataset_info,
    pretrained=False
)

model_lt = get_model(
    model_name,
    task="cls",
    dataset_info=data_module.dataset_info,
    pretrained=False
)

dummy_in = {"x": next(iter(data_module.train_dataloader()))[0]}
# generate the mase graph and initialize node metadata
mg_base = MaseGraph(model=model_base)
mg_oneshot = MaseGraph(model=model_oneshot)
mg_lt = MaseGraph(model=model_lt)

mg_base, _ = init_metadata_analysis_pass(mg_base, None)
mg_base, _ = add_common_metadata_analysis_pass(mg_base, {"dummy_in": dummy_in})
mg_base, _ = add_software_metadata_analysis_pass(mg_base, None)

mg_oneshot, _ = init_metadata_analysis_pass(mg_oneshot, None)
mg_oneshot, _ = add_common_metadata_analysis_pass(mg_oneshot, {"dummy_in": dummy_in})
mg_oneshot, _ = add_software_metadata_analysis_pass(mg_oneshot, None)

mg_lt, _ = init_metadata_analysis_pass(mg_lt, None)
mg_lt, _ = add_common_metadata_analysis_pass(mg_lt, {"dummy_in": dummy_in})
mg_lt, _ = add_software_metadata_analysis_pass(mg_lt, None)

Training the VGG network and then pruning it to see final performance

In [4]:
default_train_params = {
    "model_info": model_info,
    "data_module": data_module,
    "dataset_info": dataset_info,
    "task": "cls",
    "optimizer": "adam",
    "learning_rate": 1e-3,
    "weight_decay": 0,
    "plt_trainer_args": {
        "max_epochs": 5,
    }, 
    "auto_requeue": False,
    "save_path": None,
    "visualizer": None,
    "load_name": None,
    "load_type": None
}

base_train_params = default_train_params.copy()
base_train_params["model"] = mg_base.model
base_train_params["plt_trainer_args"]["max_epochs"] = 5

oneshot_train_params_1 = default_train_params.copy()
oneshot_train_params_1["model"] = mg_oneshot.model
oneshot_train_params_1["plt_trainer_args"]["max_epochs"] = 3

oneshot_train_params_2 = default_train_params.copy()
oneshot_train_params_2["model"] = mg_oneshot.model
oneshot_train_params_2["plt_trainer_args"]["max_epochs"] = 2

lt_train_params = default_train_params.copy()
lt_train_params["model"] = mg_lt.model
lt_train_params["plt_trainer_args"]["max_epochs"] = 1

Base model:

In [5]:
from chop.actions import test, train

train(**base_train_params)
test(**base_train_params)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified



  | Name      | Type               | Params
-------------------------------------------------
0 | model     | GraphModule        | 14.0 M
1 | loss_fn   | CrossEntropyLoss   | 0     
2 | acc_train | MulticlassAccuracy | 0     
3 | acc_val   | MulticlassAccuracy | 0     
4 | acc_test  | MulticlassAccuracy | 0     
5 | loss_val  | MeanMetric         | 0     
6 | loss_test | MeanMetric         | 0     
-------------------------------------------------
14.0 M    Trainable params
0         Non-trainable params
14.0 M    Total params
56.118    Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/bkt123/anaconda3/envs/mase/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=47` in the `DataLoader` to improve performance.
/home/bkt123/anaconda3/envs/mase/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=47` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


/home/bkt123/anaconda3/envs/mase/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=47` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     test_acc_epoch          0.289900004863739
     test_loss_epoch         1.84405517578125
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


Base Model after pruning:

In [6]:
from chop.passes.graph.transforms import (
    prune_transform_pass,
)

pass_args = {
    "weight":{
        "scope" : "global",
        "granularity" : "elementwise",
        "method" :  "l1-norm",
        "sparsity" : 0.7,},
    "activation":{
        "scope" : "global",
        "granularity" : "elementwise",
        "method" : "l1-norm",
        "sparsity" : 0.7,
    },
}
 
mg_base, _ = prune_transform_pass(mg_base, pass_args)

base_train_params["model"] = mg_base.model
test(**base_train_params)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


Testing: |          | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     test_acc_epoch         0.22779999673366547
     test_loss_epoch        2.0580010414123535
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


One shot pruning:

In [7]:
train(**oneshot_train_params_1)

mg_oneshot, _ = prune_transform_pass(mg_oneshot, pass_args)

oneshot_train_params_2["model"] = mg_oneshot.model
train(**oneshot_train_params_2)
test(**oneshot_train_params_2)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified



  | Name      | Type               | Params
-------------------------------------------------
0 | model     | GraphModule        | 14.0 M
1 | loss_fn   | CrossEntropyLoss   | 0     
2 | acc_train | MulticlassAccuracy | 0     
3 | acc_val   | MulticlassAccuracy | 0     
4 | acc_test  | MulticlassAccuracy | 0     
5 | loss_val  | MeanMetric         | 0     
6 | loss_test | MeanMetric         | 0     
-------------------------------------------------
14.0 M    Trainable params
0         Non-trainable params
14.0 M    Total params
56.118    Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified



  | Name      | Type               | Params
-------------------------------------------------
0 | model     | GraphModule        | 14.0 M
1 | loss_fn   | CrossEntropyLoss   | 0     
2 | acc_train | MulticlassAccuracy | 0     
3 | acc_val   | MulticlassAccuracy | 0     
4 | acc_test  | MulticlassAccuracy | 0     
5 | loss_val  | MeanMetric         | 0     
6 | loss_test | MeanMetric         | 0     
-------------------------------------------------
14.0 M    Trainable params
0         Non-trainable params
14.0 M    Total params
56.118    Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


Testing: |          | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     test_acc_epoch         0.46970000863075256
     test_loss_epoch        1.4413295984268188
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


Lottery Ticket:

In [8]:
from chop.passes.graph.utils import get_mase_op
import copy, torch


def prune_lottery_ticket(mg, pass_args):
    overall_sparsity = pass_args["lottery_ticket"]["sparsity"]
    num_iterations = pass_args["lottery_ticket"]["num_iterations"]

    iteration_sparsity = 1 - ((1 - overall_sparsity) ** (1 / num_iterations))

    data_module = MaseDataModule(
        name=pass_args["config"]["dataset"],
        batch_size=pass_args["config"]["batch_size"],
        model_name=pass_args["config"]["model"],
        num_workers=0,
    )
    data_module.prepare_data()
    data_module.setup()

    train_test_args = {
        "model": mg.model,
        "model_info": get_model_info(pass_args["config"]["model"]),
        "data_module": data_module,
        "dataset_info": get_dataset_info(pass_args["config"]["dataset"]),
        "task": pass_args["config"]["task"],
        "optimizer": pass_args["config"]["optimizer"],
        "learning_rate": pass_args["config"]["learning_rate"],
        "weight_decay": pass_args["config"]["weight_decay"],
        "plt_trainer_args": {
            "max_epochs": pass_args["config"]["max_epochs"],
        },
        "auto_requeue": False,
        "save_path": None,
        "visualizer": None,
        "load_name": None,
        "load_type": None,
    }

    prune_args = {
        "weight": {
            "scope": pass_args["lottery_ticket"]["scope"],
            "granularity": pass_args["lottery_ticket"]["granularity"],
            "method": pass_args["lottery_ticket"]["method"],
            "sparsity": iteration_sparsity,
        },
        "activation": {
            "scope": pass_args["lottery_ticket"]["scope"],
            "granularity": pass_args["lottery_ticket"]["granularity"],
            "method": pass_args["lottery_ticket"]["method"],
            "sparsity": iteration_sparsity,
        },
    }


    original_w_b = {}

    for node in mg.fx_graph.nodes:
        if get_mase_op(node) in ["linear", "conv2d", "conv1d"]:
            original_w_b[node.name] = {
                "weight": mg.modules[node.target].weight,
                "bias": mg.modules[node.target].bias,
                "meta_weight": node.meta["mase"].parameters["common"]["args"]["weight"]["value"],
                "meta_bias": node.meta["mase"].parameters["common"]["args"]["bias"]["value"],
            }
            
    for i in range(num_iterations):
        mg, _ = prune_transform_pass(mg, prune_args)

        train(**train_test_args)

        # copy the weights from the original model to the pruned model
        for node in mg.fx_graph.nodes:
            if get_mase_op(node) in ["linear", "conv2d", "conv1d"]:
                with torch.no_grad():
                    mg.modules[node.target].weight.copy_(original_w_b[node.name]['weight'])
                    # mg.modules[node.target].weight.copy_(original_w_b[node.name]['weight'])

                    mg.modules[node.target].bias.copy_(original_w_b[node.name]['bias'])

                    # update the mase metadata weights
                    node.meta["mase"].parameters["common"]["args"]["weight"]["value"] = original_w_b[node.name]['meta_weight']
                    
                    node.meta["mase"].parameters["common"]["args"]["bias"]["value"] = original_w_b[node.name]['meta_bias']

    train(**train_test_args)

    test(**train_test_args)

    return mg, {}

In [None]:
data_module = MaseDataModule(
    name=dataset_name,
    batch_size=batch_size,
    model_name=model_name,
    num_workers=0,
)

data_module.prepare_data()
data_module.setup()

model_info = get_model_info(model_name)
dataset_info = get_dataset_info(dataset_name)

input_generator = InputGenerator(
    data_module=data_module,
    model_info=model_info,
    task="cls",
    which_dataloader="train",
)

lt_train_params['data_module'] = data_module

In [10]:
pass_args = {
    "lottery_ticket": {
        "num_iterations": 5,
        "scope": "global",
        "granularity": "elementwise",
        "method": "l1-norm",
        "sparsity": 0.7
    },
    "config": {
        "dataset": "cifar10",
        "model": "vgg7",
        "task": "cls",
        "batch_size": 256,
        "max_epochs": 1,
        "optimizer": "adam",
        "learning_rate": 1e-3,
        "weight_decay": 0,
    }
}

mg_lt, _ = prune_lottery_ticket(mg_lt, pass_args)

lt_train_params["model"] = mg_lt.model
test(**lt_train_params)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified



  | Name      | Type               | Params
-------------------------------------------------
0 | model     | GraphModule        | 14.0 M
1 | loss_fn   | CrossEntropyLoss   | 0     
2 | acc_train | MulticlassAccuracy | 0     
3 | acc_val   | MulticlassAccuracy | 0     
4 | acc_test  | MulticlassAccuracy | 0     
5 | loss_val  | MeanMetric         | 0     
6 | loss_test | MeanMetric         | 0     
-------------------------------------------------
14.0 M    Trainable params
0         Non-trainable params
14.0 M    Total params
56.118    Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/bkt123/anaconda3/envs/mase/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=47` in the `DataLoader` to improve performance.
/home/bkt123/anaconda3/envs/mase/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=47` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified



  | Name      | Type               | Params
-------------------------------------------------
0 | model     | GraphModule        | 14.0 M
1 | loss_fn   | CrossEntropyLoss   | 0     
2 | acc_train | MulticlassAccuracy | 0     
3 | acc_val   | MulticlassAccuracy | 0     
4 | acc_test  | MulticlassAccuracy | 0     
5 | loss_val  | MeanMetric         | 0     
6 | loss_test | MeanMetric         | 0     
-------------------------------------------------
14.0 M    Trainable params
0         Non-trainable params
14.0 M    Total params
56.118    Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

In [None]:
for node in mg_lt.fx_graph.nodes:
    if get_mase_op(node) in ['linear', "conv2d", "conv1d"]:
        print(node.name)
        total_w = 0
        pruned_w = 0
        w = mg_lt.modules[node.target].weight
        for s in w:
            total_w += s.numel()
            pruned_w += s.numel() - s[s.nonzero()].numel()

        pruned_percent = pruned_w / total_w
        print(f"Pruned percent: {pruned_percent}")
        print(75*'-')