In [1]:
import sys
from pathlib import Path

# Resolve the machop package directory (two levels above the current working
# directory) and append it to sys.path so the chop imports below can be found.
machop_path = Path(".").resolve().parent.parent.joinpath("machop")
assert machop_path.exists(), f"Failed to find machop at: {machop_path}"
sys.path.append(str(machop_path))

from chop.dataset import MaseDataModule, get_dataset_info
from chop.tools.logger import set_logging_verbosity

from chop.passes.graph import (
    add_common_metadata_analysis_pass,
    init_metadata_analysis_pass,
    add_software_metadata_analysis_pass,
)
from chop.tools.get_input import InputGenerator
from chop.ir.graph.mase_graph import MaseGraph

from chop.models import get_model_info, get_model

set_logging_verbosity("info")


[32mINFO    [0m [34mSet logging level to info[0m


In [3]:
# Settings reused by the later cells of this notebook.
batch_size, model_name, dataset_name = 256, "jsc-custom", "jsc"

# Build the data module for the chosen dataset/model pair and run its
# prepare_data() and setup() hooks.
data_module = MaseDataModule(
    name=dataset_name, batch_size=batch_size, model_name=model_name, num_workers=16
)
data_module.prepare_data()
data_module.setup()

# Look up model and dataset info by name.
model_info = get_model_info(model_name)
dataset_info = get_dataset_info(dataset_name)

# Input generator over the "train" dataloader for the classification task.
# NOTE(review): this instance is replaced by the next cell (which adds
# max_batches=1) before anything consumes it.
input_generator = InputGenerator(
    data_module=data_module, model_info=model_info, task="cls", which_dataloader="train"
)


In [4]:
# Two separate, non-pretrained instances of the same architecture:
# `model_base` is trained and then pruned, `model_oneshot` is used for the
# train -> prune -> retrain flow in later cells.
model_base = get_model(
    model_name,
    task="cls",
    dataset_info=dataset_info,
    pretrained=False
)

model_oneshot = get_model(
    model_name,
    task="cls",
    dataset_info=dataset_info,
    pretrained=False
)

# Rebuild the input generator with max_batches=1; it replaces the generator
# created in the previous cell.
input_generator = InputGenerator(
    data_module=data_module,
    model_info=model_info,
    task="cls",
    which_dataloader="train",
    max_batches=1
)

# First item yielded by the generator; passed to the common-metadata pass
# below as "dummy_in".
dummy_in = next(iter(input_generator))
# Generate the MASE graph for each model and initialize node metadata.
mg_base = MaseGraph(model=model_base)
mg_oneshot = MaseGraph(model=model_oneshot)

# Each pass returns a (graph, info) pair; only the graph is kept.
# The three passes are applied in the same order to both graphs.
mg_base, _ = init_metadata_analysis_pass(mg_base, None)
mg_base, _ = add_common_metadata_analysis_pass(mg_base, {"dummy_in": dummy_in})
mg_base, _ = add_software_metadata_analysis_pass(mg_base, None)

mg_oneshot, _ = init_metadata_analysis_pass(mg_oneshot, None)
mg_oneshot, _ = add_common_metadata_analysis_pass(mg_oneshot, {"dummy_in": dummy_in})
mg_oneshot, _ = add_software_metadata_analysis_pass(mg_oneshot, None)

Train the JSC network, then prune it and measure the final test performance.

In [5]:
import copy
from lightning.pytorch.loggers.tensorboard import TensorBoardLogger


# Baseline keyword arguments that are unpacked into the `train` / `test`
# actions. Every experiment takes its own deep copy and overrides the model
# and the epoch budget.
default_train_params = dict(
    model_info=model_info,
    data_module=data_module,
    dataset_info=dataset_info,
    task="cls",
    optimizer="adam",
    learning_rate=1e-3,
    weight_decay=0,
    plt_trainer_args=dict(max_epochs=10),
    auto_requeue=False,
    save_path=None,
    visualizer=None,
    load_name=None,
    load_type=None,
)

# Baseline run: 5 epochs on the base model, logged through a TensorBoard logger.
base_train_params = copy.deepcopy(default_train_params)
base_train_params.update(model=mg_base.model)
base_train_params["plt_trainer_args"].update(max_epochs=5)

visualizer = TensorBoardLogger(save_dir="./project/vgg-cifar/tensorboard")
base_train_params.update(visualizer=visualizer)

# One-shot flow, stage 1: 3 epochs on the one-shot model (before pruning).
oneshot_train_params_1 = copy.deepcopy(default_train_params)
oneshot_train_params_1.update(model=mg_oneshot.model)
oneshot_train_params_1["plt_trainer_args"].update(max_epochs=3)

# One-shot flow, stage 2: 2 epochs; a later cell swaps in the pruned model.
oneshot_train_params_2 = copy.deepcopy(default_train_params)
oneshot_train_params_2.update(model=mg_oneshot.model)
oneshot_train_params_2["plt_trainer_args"].update(max_epochs=2)

Base model:

In [6]:
from chop.actions import test, train

# Train the base model with the baseline parameters (max_epochs=5), then run
# the test action with the same parameter set.
train(**base_train_params)
test(**base_train_params)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name      | Type               | Params
-------------------------------------------------
0 | model     | GraphModule        | 1.3 K 
1 | loss_fn   | CrossEntropyLoss   | 0     
2 | acc_train | MulticlassAccuracy | 0     
3 | acc_val   | MulticlassAccuracy | 0     
4 | acc_test  | MulticlassAccuracy | 0     
5 | loss_val  | MeanMetric         | 0     
6 | loss_test | MeanMetric         | 0     
-------------------------------------------------
1.3 K     Trainable params
0         Non-trainable params
1.3 K     Total params
0.005     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Testing: |          | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     test_acc_epoch         0.7335353493690491
     test_loss_epoch        0.7563168406486511
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


Base Model after pruning:

In [7]:
from chop.passes.graph.transforms import (
    prune_transform_pass,
)

# Weights and activations get the same pruning settings: global scope,
# element-wise granularity, l1-norm ranking, sparsity 0.7. Each target
# receives its own dict object.
pass_args = {
    target: {
        "scope": "global",
        "granularity": "elementwise",
        "method": "l1-norm",
        "sparsity": 0.7,
    }
    for target in ("weight", "activation")
}

# Prune the already-trained base graph; only the returned graph is kept.
mg_base, _ = prune_transform_pass(mg_base, pass_args)

# Run the test action on the pruned model, with no retraining in between.
base_train_params.update(model=mg_base.model)
test(**base_train_params)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Testing: |          | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     test_acc_epoch         0.32749971747398376
     test_loss_epoch         1.498740792274475
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


One shot pruning:

In [8]:
from chop.actions import test, train
from chop.passes.graph.transforms import (
    prune_transform_pass,
)

# Stage 1 of the one-shot flow: train the unpruned model (max_epochs=3).
train(**oneshot_train_params_1)

# Pruning settings consumed by the next cell: identical for weights and
# activations, each target receiving its own dict object.
pass_args = {
    target: {
        "scope": "global",
        "granularity": "elementwise",
        "method": "l1-norm",
        "sparsity": 0.7,
    }
    for target in ("weight", "activation")
}

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs



  | Name      | Type               | Params
-------------------------------------------------
0 | model     | GraphModule        | 1.3 K 
1 | loss_fn   | CrossEntropyLoss   | 0     
2 | acc_train | MulticlassAccuracy | 0     
3 | acc_val   | MulticlassAccuracy | 0     
4 | acc_test  | MulticlassAccuracy | 0     
5 | loss_val  | MeanMetric         | 0     
6 | loss_test | MeanMetric         | 0     
-------------------------------------------------
1.3 K     Trainable params
0         Non-trainable params
1.3 K     Total params
0.005     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=3` reached.


In [9]:
# Stage 2 of the one-shot flow: put the model trained in the previous cell
# back on the graph and prune it with the `pass_args` defined there.
mg_oneshot.model = oneshot_train_params_1["model"]
mg_oneshot, _ = prune_transform_pass(mg_oneshot, pass_args)

# Train the pruned model with the stage-2 parameters (max_epochs=2), then run
# the test action on it.
oneshot_train_params_2["model"] = mg_oneshot.model
train(**oneshot_train_params_2)
test(**oneshot_train_params_2)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs



  | Name      | Type               | Params
-------------------------------------------------
0 | model     | GraphModule        | 1.3 K 
1 | loss_fn   | CrossEntropyLoss   | 0     
2 | acc_train | MulticlassAccuracy | 0     
3 | acc_val   | MulticlassAccuracy | 0     
4 | acc_test  | MulticlassAccuracy | 0     
5 | loss_val  | MeanMetric         | 0     
6 | loss_test | MeanMetric         | 0     
-------------------------------------------------
1.3 K     Trainable params
0         Non-trainable params
1.3 K     Total params
0.005     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=2` reached.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Testing: |          | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     test_acc_epoch         0.5370740294456482
     test_loss_epoch         1.142966389656067
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


Iterative Pruning

In [13]:
# Recreate the data module from the notebook-level settings and run its
# prepare_data() and setup() hooks.
data_module = MaseDataModule(
    name=dataset_name, batch_size=batch_size, model_name=model_name, num_workers=16
)
data_module.prepare_data()
data_module.setup()

# Look up model and dataset info by name.
model_info = get_model_info(model_name)
dataset_info = get_dataset_info(dataset_name)

# Input generator over the "train" dataloader for the classification task.
input_generator = InputGenerator(
    data_module=data_module, model_info=model_info, task="cls", which_dataloader="train"
)

# Fresh, non-pretrained model for the iterative-pruning experiment. Note that
# the dataset info is read from the data module here rather than from
# `dataset_info`.
model_lt = get_model(
    model_name, task="cls", dataset_info=data_module.dataset_info, pretrained=False
)

In [14]:
from chop.actions.optimize.prune import prune_iterative
from lightning.pytorch.loggers.tensorboard import TensorBoardLogger

import torch
from torch.utils.tensorboard import SummaryWriter
# TensorBoard writer handed to prune_iterative as its visualizer.
writer = SummaryWriter(log_dir="./project/vgg-cifar/tensorboard")

# Configuration for the iterative pruning action: a single pruning iteration
# at sparsity 0.7, plus the training settings passed alongside it.
pass_args = dict(
    iterative_prune=dict(
        num_iterations=1,
        scope="global",
        granularity="elementwise",
        method="l1-norm",
        sparsity=0.7,
    ),
    train=dict(
        name="accuracy",
        data_loader="train_dataloader",
        num_samples=10000,
        max_epochs=10,
        lr_scheduler="linear",
        optimizer="adam",
        learning_rate=1e-3,
        num_warmup_steps=0,
    ),
)

# The configuration is nested under the "prune" key; the call's three return
# values are bound to `model`, `mg` and `results`.
model, mg, results = prune_iterative(
    model_lt,
    model_info,
    "cls",
    dataset_info,
    data_module,
    {"prune": pass_args},
    visualizer=writer,
)


  0%|          | 0/3 [00:00<?, ?it/s]
100%|██████████| 15/15 [00:00<00:00, 15627.06it/s]
100%|██████████| 1/1 [00:01<00:00,  1.21s/it]
  0%|          | 0/3 [00:02<?, ?it/s]


In [None]:
import copy
from chop.actions import test, train

# Take a deep copy of the default parameters and attach the model returned by
# prune_iterative.
lt_train_params = copy.deepcopy(default_train_params)
lt_train_params["model"] = model

# Run the test action only; no further training happens in this cell.
test(**lt_train_params)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


Testing: |          | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     test_acc_epoch         0.26980000734329224
     test_loss_epoch        1.9204679727554321
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


In [None]:
from chop.passes.graph.utils import get_mase_op

# Per-layer sparsity report for the model returned by prune_iterative.
#
# For every linear / conv2d / conv1d node, print the node name and the
# fraction of weight elements that are exactly zero. The label printed is
# "Pruned percent", but the value is a fraction between 0 and 1.

# Not populated by this cell; kept so the notebook namespace is unchanged.
last = []

# Wrap the model in a MASE graph and run the metadata passes.
mg_lt = MaseGraph(model=model)
mg_lt, _ = init_metadata_analysis_pass(mg_lt, None)
mg_lt, _ = add_common_metadata_analysis_pass(mg_lt, {"dummy_in": dummy_in})
mg_lt, _ = add_software_metadata_analysis_pass(mg_lt, None)

for node in mg_lt.fx_graph.nodes:
    if get_mase_op(node) not in ("linear", "conv2d", "conv1d"):
        continue

    print(node.name)
    module = mg_lt.modules[node.target]
    w = module.weight

    # Count zeros directly instead of materialising a filtered copy of the
    # weights just to measure its length.
    total_w = w.numel()
    pruned_w = total_w - int((w != 0).sum())

    # Guard against a zero-element weight so the report cannot divide by zero.
    pruned_percent = pruned_w / total_w if total_w else 0.0
    print(f"Pruned percent: {pruned_percent}")
    print(75*'-')

block_1_0
216
Num of masks:  5
Pruned percent: 0.39814814814814814
---------------------------------------------------------------------------
block_2_0
1152
Num of masks:  5
Pruned percent: 0.6675347222222222
---------------------------------------------------------------------------
block_3_0
1536
Num of masks:  5
Pruned percent: 0.5696614583333334
---------------------------------------------------------------------------
block_4_0
6144
Num of masks:  5
Pruned percent: 0.7576497395833334
---------------------------------------------------------------------------
linear
640
Num of masks:  5
Pruned percent: 0.61875
---------------------------------------------------------------------------


Iterative pruning with more iterations (`num_iterations = 5`)

In [None]:
# Recreate the data module from the notebook-level settings and run its
# prepare_data() and setup() hooks.
data_module = MaseDataModule(
    name=dataset_name, batch_size=batch_size, model_name=model_name, num_workers=16
)
data_module.prepare_data()
data_module.setup()

# Look up model and dataset info by name.
model_info = get_model_info(model_name)
dataset_info = get_dataset_info(dataset_name)

# Input generator over the "train" dataloader for the classification task.
input_generator = InputGenerator(
    data_module=data_module, model_info=model_info, task="cls", which_dataloader="train"
)

# Fresh, non-pretrained model so this experiment does not start from weights
# touched by an earlier cell. The dataset info is read from the data module
# here rather than from `dataset_info`.
model_lt = get_model(
    model_name, task="cls", dataset_info=data_module.dataset_info, pretrained=False
)

In [None]:
from chop.actions.optimize.prune import prune_iterative
from lightning.pytorch.loggers.tensorboard import TensorBoardLogger

import torch
from torch.utils.tensorboard import SummaryWriter
# TensorBoard writer handed to prune_iterative as its visualizer.
writer = SummaryWriter(log_dir='./project/vgg-cifar/tensorboard')

# Same configuration as the single-iteration run, except that pruning is
# spread over 5 iterations.
pass_args = {
    "iterative_prune": {
        "num_iterations": 5,
        "scope": "global",
        "granularity": "elementwise",
        "method": "l1-norm",
        "sparsity": 0.7
    },
    "train": {
        "name": "accuracy",
        "data_loader": "train_dataloader",
        "num_samples": 10000,
        "max_epochs": 10,
        "lr_scheduler": "linear",
        "optimizer": "adam",
        "learning_rate": 1e-3,
        "num_warmup_steps": 0,
    }
}


# Unpack three values, matching the single-iteration prune_iterative call
# earlier in this notebook (the one that was actually executed). The previous
# two-name unpacking here disagreed with it, so one of the two cells would
# raise ValueError.
# NOTE(review): confirm the return arity against chop.actions.optimize.prune.
model, mg, results = prune_iterative(
    model_lt,
    model_info,
    "cls",
    dataset_info,
    data_module,
    {"prune": pass_args},
    visualizer=writer
)


In [None]:
# Take a deep copy of the default parameters and attach the model returned by
# the 5-iteration prune_iterative run.
lt_train_params = copy.deepcopy(default_train_params)
lt_train_params["model"] = model

# Run the test action only; no further training happens in this cell.
test(**lt_train_params)

In [None]:
from chop.passes.graph.utils import get_mase_op

# Per-layer sparsity report for the model returned by the 5-iteration
# prune_iterative run.
#
# For every linear / conv2d / conv1d node, print the node name and the
# fraction of weight elements that are exactly zero. The label printed is
# "Pruned percent", but the value is a fraction between 0 and 1.

# Not populated by this cell; kept so the notebook namespace is unchanged.
last = []

# Wrap the model in a MASE graph and run the metadata passes.
mg_lt = MaseGraph(model=model)
mg_lt, _ = init_metadata_analysis_pass(mg_lt, None)
mg_lt, _ = add_common_metadata_analysis_pass(mg_lt, {"dummy_in": dummy_in})
mg_lt, _ = add_software_metadata_analysis_pass(mg_lt, None)

for node in mg_lt.fx_graph.nodes:
    if get_mase_op(node) not in ("linear", "conv2d", "conv1d"):
        continue

    print(node.name)
    module = mg_lt.modules[node.target]
    w = module.weight

    # Count zeros directly instead of materialising a filtered copy of the
    # weights just to measure its length.
    total_w = w.numel()
    pruned_w = total_w - int((w != 0).sum())

    # Guard against a zero-element weight so the report cannot divide by zero.
    pruned_percent = pruned_w / total_w if total_w else 0.0
    print(f"Pruned percent: {pruned_percent}")
    print(75*'-')