In [1]:
import sys
import logging
import os
from pathlib import Path
from pprint import pprint as pp

# figure out the correct path
machop_path = Path(".").resolve().parent.parent /"machop"
assert machop_path.exists(), "Failed to find machop at: {}".format(machop_path)
sys.path.append(str(machop_path))

from chop.tools.checkpoint_load import load_model
from chop.dataset import MaseDataModule, get_dataset_info
from chop.tools.logger import set_logging_verbosity

from chop.passes.graph.analysis import (
    report_node_meta_param_analysis_pass,
    profile_statistics_analysis_pass,
)
from chop.passes.graph import (
    add_common_metadata_analysis_pass,
    init_metadata_analysis_pass,
    add_software_metadata_analysis_pass,
)
from chop.tools.get_input import InputGenerator
from chop.ir.graph.mase_graph import MaseGraph

from chop.models import get_model_info, get_model

set_logging_verbosity("info")


[32mINFO    [0m [34mSet logging level to info[0m


In [18]:
batch_size = 512
model_name = "vgg7"
dataset_name = "cifar10"

data_module = MaseDataModule(
    name=dataset_name,
    batch_size=batch_size,
    model_name=model_name,
    num_workers=0,
    # custom_dataset_cache_path="../../chop/dataset"
)
data_module.prepare_data()
data_module.setup()

dataset_info = get_dataset_info(dataset_name)

model_info = get_model_info(model_name)
model = get_model(
    model_name,
    task="cls",
    dataset_info=dataset_info,
    pretrained=False,
    checkpoint = None)


# LAB1_CUSTOM_PATH = "/home/bkt123/dev/advanced-deep-learning-systems/mase/mase_output/lab-1_jsc-custom/software/training_ckpts/best.ckpt"
# model = load_model(load_name=LAB1_CUSTOM_PATH, load_type="pl", model=model)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [19]:
from chop.actions import test, train
import torch

# print(isinstance(mg.model, torch.nn.Module))

task = "cls"

save_path = "./vgg-uncompressed/"

train_params = {
    "model": model,
    "model_info": model_info,
    "data_module": data_module,
    "dataset_info": dataset_info,
    "task": task,
    "optimizer": "adam",
    "learning_rate": 3e-3,
    "weight_decay": 0,
    "plt_trainer_args": {
        "max_epochs": 1,
    }, 
    "auto_requeue": False,
    "save_path": None,
    "visualizer": None,
    "load_name": None,
    "load_type": None
}


In [None]:

train_params["save_path"] = save_path

train(**train_params)

test(**train_params)

In [20]:
from chop.passes.graph.interface.save_and_load import save_mase_graph_interface_pass

mg = MaseGraph(model)

input_generator = InputGenerator(
    data_module=data_module,
    model_info=model_info,
    task="cls",
    which_dataloader="train",
    max_batches=1
)
dummy_in = next(iter(input_generator))

mg, _ = init_metadata_analysis_pass(mg, None)
mg, _ = add_common_metadata_analysis_pass(mg, {"dummy_in": dummy_in})
mg, _ = add_software_metadata_analysis_pass(mg, None)

torch.save(mg.model.state_dict(), f'{save_path}/model_weights.pth')


In [9]:
from chop.actions.optimize.prune import prune_iterative
from lightning.pytorch.loggers.tensorboard import TensorBoardLogger

import torch
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(log_dir='./project/vgg-cifar/tensorboard')

pass_args = {
    "iterative_prune": {
        "num_iterations": 1,
        "scope": "global",
        "granularity": "elementwise",
        "method": "l1-norm",
        "sparsity": 0.5
    },
    "train": {
        "name": "accuracy",
        "data_loader": "train_dataloader",
        "num_samples": 1000,
        "max_epochs": 1,
        "lr_scheduler": "linear",
        "optimizer": "adam",
        "learning_rate": 1e-3,
        "num_warmup_steps": 0,
    }
}

model = get_model(
    model_name,
    task="cls",
    dataset_info=dataset_info,
    pretrained=False,
    checkpoint = None
)


model, mg, results = prune_iterative(
    model,
    model_info,
    "cls",
    dataset_info,
    data_module,
    {"prune": pass_args},
    visualizer=writer
)


Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


  0%|          | 0/9 [00:00<?, ?it/s]
100%|██████████| 29/29 [00:00<00:00, 1932.83it/s]
100%|██████████| 1/1 [00:11<00:00, 11.35s/it]
  0%|          | 0/9 [01:01<?, ?it/s]


In [10]:
quantize_args = {
    "by": "name",
    "default": {
        "config": {
            "name": "integer",
            # data
            "data_in_width": 8,
            "data_in_frac_width": 5,
            # weight
            "weight_width": 8,
            "weight_frac_width": 5,
            # bias
            "bias_width": 8,
            "bias_frac_width": 5,
        }
    }
}

train_args = {
    "name": "accuracy",
    "data_loader": "train_dataloader",
    "num_samples": 1000,
    "max_epochs": 1,
    "lr_scheduler": "linear",
    "optimizer": "adam",
    "learning_rate": 3e-3,
    "num_warmup_steps": 0,
}

In [11]:
from chop.actions import quantize_model

config = {
    "quantization": {
        "quantization_config": quantize_args,
        "train": train_args,
    }
}

model, mg, results = quantize_model(
    model,
    model_info,
    "cls",
    dataset_info,
    data_module,
    config,
)

print(results)

train_params["model"] = model

test(**train_params)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


{'loss': 2.3024988174438477, 'accuracy': 0.10000000149011612}
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


/home/bkt123/anaconda3/envs/mase/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=47` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     test_acc_epoch         0.10000000149011612
     test_loss_epoch        2.4431118965148926
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


In [12]:
from collections import Counter
import heapq
import json
import bitarray, os

from chop.passes.graph.utils import get_node_actual_target

def huffman_encode(freqs):
    heap = [[weight, [symbol, ""]] for symbol, weight in freqs.items()]
    heapq.heapify(heap)
    while len(heap) > 1:
        lo = heapq.heappop(heap)
        hi = heapq.heappop(heap)
        for pair in lo[1:]:
            pair[1] = '0' + pair[1]
        for pair in hi[1:]:
            pair[1] = '1' + pair[1]
        heapq.heappush(heap, [lo[0] + hi[0]] + lo[1:] + hi[1:])
    return dict(sorted(heapq.heappop(heap)[1:], key=lambda p: (len(p[-1]), p)))

def find_module_of_parameter(model, full_param_name):
    """
    Find the module that contains the specified parameter.
    
    Parameters:
    - model: The PyTorch model to search within.
    - full_param_name: The full name of the parameter (including module path).
    
    Returns:
    - The module containing the parameter, or None if not found.
    """
    # Split the full parameter name into parts
    parts = full_param_name.split('.')
    submodule_path = parts[:-1]  # Everything except the last part, which is the parameter name
    
    # Start with the base model
    current_module = model
    
    # Traverse the modules according to the path
    for submodule_name in submodule_path:
        # Update the current_module to go deeper
        if hasattr(current_module, submodule_name):
            current_module = getattr(current_module, submodule_name)
        else:
            # Return None if any part of the path doesn't exist
            return None
    
    return current_module

def flatten_parameters(model, mg):
    """Flatten and concatenate all model parameters into a list."""
    named_params = list(model.named_parameters())

    get_named_params = lambda name, params: [param for param in params if name in param[0]]

    param_list = []
    for node in mg.fx_graph.nodes:
        module = get_node_actual_target(node)
        if node.target in mg.modules:
            named_params_node = get_named_params(f"{node.target}.", named_params)      
            for name, param in module.named_parameters():
                actual_name, actual_param = get_named_params(name, named_params_node)[0]
                actual_module = find_module_of_parameter(model, actual_name)
                if isinstance(actual_module, torch.nn.utils.parametrize.ParametrizationList) and "weight" in actual_name:
                    actual_param = actual_param * actual_module[0].mask.to(actual_param.dtype)
                if hasattr(module, "w_quantizer"):
                    actual_param = module.w_quantizer(actual_param)
                if hasattr(module, "b_quantizer"):
                    actual_param = module.b_quantizer(actual_param)
                param_list.extend(actual_param.flatten().tolist())


    
    return param_list


def huffman_encode_pass(mg, pass_args):
    model = pass_args['model']

    save_dir = pass_args['save_dir']

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
        os.makedirs(os.path.join(save_dir, "parameters"))
        os.makedirs(os.path.join(save_dir, "masks"))

    params = flatten_parameters(model, mg)
    freqs = Counter(params)

    global_huffman_codes = huffman_encode(freqs)

    huffman_codes_path = os.path.join(save_dir, "global_huffman_codes.json")
    with open(huffman_codes_path, 'w') as file:
        json.dump(global_huffman_codes, file)
    
    named_params = list(model.named_parameters())

    get_named_params = lambda name, params: [param for param in params if name in param[0]]
    for node in mg.fx_graph.nodes:
        module = get_node_actual_target(node)
        if node.target in mg.modules:
            named_params_node = get_named_params(f"{node.target}.", named_params)      
            for name, param in module.named_parameters():
                actual_name, actual_param = get_named_params(name, named_params_node)[0]
                actual_module = find_module_of_parameter(model, actual_name)
                if isinstance(actual_module, torch.nn.utils.parametrize.ParametrizationList) and "weight" in actual_name:
                    actual_param = actual_param * actual_module[0].mask.to(actual_param.dtype)
                if hasattr(module, "w_quantizer"):
                    actual_param = module.w_quantizer(actual_param)
                if hasattr(module, "b_quantizer"):
                    actual_param = module.b_quantizer(actual_param)
                
                actual_param_values = actual_param.flatten().tolist()
                # Encode using global Huffman codes
                encoded_values = ''.join([global_huffman_codes[val] for val in actual_param_values])
                ba = bitarray.bitarray(encoded_values)
                
                # Save encoded parameters
                # encoded_file_path = os.path.join(save_dir, "parameters", f"{actual_name}.bin")
                encoded_file_path = os.path.join(save_dir, "parameters", f"{node.target}.{name}.{len(ba)}.bin")

                with open(encoded_file_path, 'wb') as encoded_file:
                    ba.tofile(encoded_file)


    # for name, buffer in model.named_buffers():
    #     # Assuming buffer is already quantized
    #     if not "mask" in name:
    #         continue

    #     mask = buffer.flatten().tolist()

    #     binary_code = {True: '1', False: '0'}
        
    #     # Encode using global Huffman codes
    #     encoded_values = ''.join([binary_code[val] for val in mask])
    #     ba = bitarray.bitarray(encoded_values)
        
    #     # Save encoded parameters
    #     encoded_file_path = os.path.join(save_dir, "masks", f"{name}.bin")
    #     with open(encoded_file_path, 'wb') as encoded_file:
    #         ba.tofile(encoded_file)

    return mg, _

In [13]:
mg, _ = huffman_encode_pass(mg, {"save_dir": "./vgg-compressed", "model": model})

In [14]:
def decode_huffman(encoded_data, huffman_codes):
    reverse_huffman_codes = {v: k for k, v in huffman_codes.items()}
    decoded_data = []
    code = ""
    for bit in encoded_data:
        code += str(bit)
        if code in reverse_huffman_codes:
            decoded_data.append(float(reverse_huffman_codes[code]))
            code = ""
    return decoded_data

def load_huffman_encoded_model(mg, pass_args):
    load_dir = pass_args['load_dir']
    dummy_in = pass_args['dummy_in']

    model = mg.model

    with open(f"{load_dir}/global_huffman_codes.json", 'r') as file:
        huffman_codes = json.load(file)

    encoded_parameter_files = os.listdir(os.path.join(load_dir, "parameters"))

    for name, param in model.named_parameters():
        file_name = None
        for file in encoded_parameter_files:
            if file.startswith(name):
                file_name = file
                break
        
        if file_name is None:
            continue

        encoded_file_path = os.path.join(load_dir, "parameters", file_name)
        num_bits = int(file_name.split('.')[-2])

        with open(encoded_file_path, 'rb') as encoded_file:
            encoded_data = bitarray.bitarray()
            encoded_data.fromfile(encoded_file)
            encoded_data = encoded_data[:num_bits]
            decoded_data = decode_huffman(encoded_data, huffman_codes)
            decoded_data = torch.tensor(decoded_data)
            decoded_data = decoded_data.view(param.shape)
            param.data = decoded_data

    mg = MaseGraph(model)

    mg, _ = init_metadata_analysis_pass(mg, None)
    mg, _ = add_common_metadata_analysis_pass(mg, {"dummy_in": dummy_in})
    mg, _ = add_software_metadata_analysis_pass(mg, None)

    return mg, {}

In [15]:
model_rebuilt = get_model(
    model_name,
    task="cls",
    dataset_info=dataset_info,
    pretrained=False,
    checkpoint = None
)

input_generator = InputGenerator(
    data_module=data_module,
    model_info=model_info,
    task="cls",
    which_dataloader="train",
    max_batches=1
)
dummy_in = next(iter(input_generator))

mg_r = MaseGraph(model_rebuilt)

mg_r, _ = init_metadata_analysis_pass(mg_r, None)
mg_r, _ = add_common_metadata_analysis_pass(mg_r, {"dummy_in": dummy_in})
mg_r, _ = add_software_metadata_analysis_pass(mg_r, None)

mg_r, _ = load_huffman_encoded_model(mg_r, {"load_dir": "./vgg-compressed", "dummy_in": dummy_in})

model_rebuilt = mg_r.model

print(model_rebuilt.state_dict())

train_params["model"] = model_rebuilt

test(**train_params)


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


OrderedDict([('feature_layers.0.weight', tensor([[[[ 0.0312,  0.1562,  0.0625],
          [-0.0312, -0.1250, -0.0312],
          [-0.0312,  0.1250, -0.0938]],

         [[ 0.1875,  0.0938, -0.0938],
          [-0.0938, -0.1875,  0.0312],
          [ 0.1562,  0.0938, -0.1562]],

         [[-0.0938, -0.0312,  0.1562],
          [ 0.0312, -0.0625, -0.0625],
          [-0.0312,  0.1875,  0.0000]]],


        [[[-0.0938,  0.1250,  0.0938],
          [ 0.1562, -0.0312, -0.1562],
          [ 0.1875, -0.1250,  0.0000]],

         [[ 0.0000,  0.1250,  0.0312],
          [-0.0938, -0.0625, -0.0938],
          [ 0.1250, -0.0312,  0.0312]],

         [[ 0.0625, -0.1250, -0.1562],
          [-0.0625,  0.0312,  0.1250],
          [-0.0625, -0.1875,  0.1562]]],


        [[[-0.0312,  0.0000,  0.1250],
          [ 0.0938,  0.1250, -0.0312],
          [ 0.1250,  0.0312, -0.1250]],

         [[-0.1562,  0.0625,  0.0625],
          [ 0.0000,  0.1562,  0.1562],
          [-0.0938, -0.1562, -0.0938]],

   

Testing: |          | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     test_acc_epoch         0.10000000149011612
     test_loss_epoch         2.302914619445801
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
