## Loading Libraries

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import tensor
from torch.optim import Adam
from torch.nn import Linear, Sequential, ReLU, Identity, BatchNorm1d as BN
from torch_geometric.datasets import TUDataset
from torch_geometric.utils import degree
from torch_geometric.nn import global_mean_pool
from torch_geometric.data import DataLoader, DenseDataLoader as DenseLoader
from torch_geometric.data import Batch


from sklearn.model_selection import StratifiedKFold
from collections import OrderedDict
from tqdm import tqdm
import argparse
import statistics as stat
from tabulate import tabulate
import statistics as stat
import time
import os
import numpy as np
import random




In [10]:

#Quantization
from dq.quantization import IntegerQuantizer
from dq.linear_quantized import LinearQuantized
from dq.baseline_quant import GINConvQuant
from dq.multi_quant import evaluate_prob_mask, GINConvMultiQuant
from dq.transforms import ProbabilisticHighDegreeMask

#loading dataset and training
from dataset import get_dataset
from train_eval import cross_validation_with_val_set
from gin import GIN
import utils as utils

# output dir and tensorboard writer
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
from pathlib import Path

# Computing Energy and cpu usage 
import psutil
import itertools
import tracemalloc
import gc


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Setting Arguments 

In [11]:

parser = argparse.ArgumentParser()
parser.add_argument("--model",type=str,default='GIN')
parser.add_argument("--epochs", type=int, default=200)
parser.add_argument("--batch_size", type=int, default=128)
parser.add_argument("--num_layers", type=int, default=5)
parser.add_argument("--hidden", type=int, default=64)
parser.add_argument("--lr", type=float, default=0.01)
parser.add_argument("--wd", type=float, default=4e-5)
parser.add_argument("--noise", type=float, default=1.0)
parser.add_argument("--lr_decay_factor", type=float, default=0.5)
parser.add_argument("--lr_decay_step_size", type=int, default=50)

parser.add_argument("--path", type=str, default="/datasets/", help="where all datasets live")
parser.add_argument("--outdir", type=str, default="D:/output/BBBPBINexps/INT8-DQ")

parser.add_argument("--DQ", action="store_true", help="enables DegreeQuant")
parser.add_argument("--low", type=float, default=0.0)
parser.add_argument("--change", type=float, default=0.1)
parser.add_argument("--sample_prop", type=float, default=None)

parser.add_argument("--result_folder",type=str,default='result')
# Path to checkpoint
parser.add_argument("--check_folder",type=str,default='checkpoint')
# Path to dataset
parser.add_argument("--path2dataset",type=str,default='/')

quant_mode = parser.add_mutually_exclusive_group(required=False)
quant_mode.add_argument("--fp32", action="store_true", help="no quantization")
quant_mode.add_argument("--int8", action="store_true", help="INT8 quant")
quant_mode.add_argument("--int4", action="store_true", help="INT4 quant")

ste_mode = parser.add_mutually_exclusive_group(required=False)
ste_mode.add_argument("--ste_abs", action="store_true", help="STE-ABS")
ste_mode.add_argument("--ste_mom", action="store_true", help="STE-MOM")
ste_mode.add_argument("--ste_per", action="store_true", help="STE-PER")
ste_mode.add_argument("--gc_abs", action="store_true", help="GC-ABS")
ste_mode.add_argument("--gc_mom", action="store_true", help="GC-MOM")
ste_mode.add_argument("--gc_per", action="store_true", help="GC-PER")




args = parser.parse_args(['--fp32', '--ste_abs'])



### Generating the qConfig

- INT4=True $\Rightarrow$ args.int4=True
- DQ=True $\Rightarrow$ args.DQ=True
- gc-per=True $\Rightarrow$ args.gc_per=True

In [12]:
args.DQ=True
args.fp32=False
args.int4=True
args.int8=False


args.ste_abs=False
args.ste_mom=False
args.ste_per=True
args.gc_abs=False
args.gc_mom=False
args.gc_per=False


if args.fp32:
    qypte = "FP32"
elif args.int8:
    qypte = "INT8"
elif args.int4:
    qypte = "INT4"
else:
    raise NotImplementedError

ste = False
momentum = False
percentile = None

# ste quant
if args.ste_abs:
    ste = True
elif args.ste_mom:
    ste = True
    momentum = True
elif args.gc_abs:
    pass
elif args.gc_mom:
    momentum = True
elif args.ste_per:
    ste = True
    percentile = 0.01 if args.int4 else 0.001
elif args.gc_per:
    percentile = 0.01 if args.int4 else 0.001
else:
    raise NotImplementedError



if args.DQ:
    DQ = {"prob_mask_low": args.low, "prob_mask_change": args.change}

print(args)

Namespace(model='GIN', epochs=200, batch_size=128, num_layers=5, hidden=64, lr=0.01, wd=4e-05, noise=1.0, lr_decay_factor=0.5, lr_decay_step_size=50, path='/datasets/', outdir='D:/output/BBBPBINexps/INT8-DQ', DQ=True, low=0.0, change=0.1, sample_prop=None, result_folder='result', check_folder='checkpoint', path2dataset='/', fp32=False, int8=False, int4=True, ste_abs=False, ste_mom=False, ste_per=True, gc_abs=False, gc_mom=False, gc_per=False)


## Loading dataset and Transformation

In [13]:
from torch_geometric.datasets import MoleculeNet
from torch_geometric.utils import dense_to_sparse
from torch.utils.data import random_split, Subset
from torch_geometric.data import Data, InMemoryDataset, DataLoader

def load_MolecueNet(dataset_dir, dataset_name, task=None):
    """ Attention the multi-task problems not solved yet """
    molecule_net_dataset_names = {name.lower(): name for name in MoleculeNet.names.keys()}
    dataset = MoleculeNet(root=dataset_dir, name=molecule_net_dataset_names[dataset_name.lower()])
    dataset.data.x = dataset.data.x.float()
    if task is None:
        dataset.data.y = dataset.data.y.squeeze().long()
    else:
        dataset.data.y = dataset.data.y[task].long()
    dataset.node_type_dict = None
    dataset.node_color = None
    return dataset

In [14]:

import torch
from torch_geometric.data import Batch
from torch_geometric.utils import degree

class ProbabilisticHighDegreeMask:
    def __init__(self, low_quantise_prob, high_quantise_prob, per_graph=True):
        self.low_prob = low_quantise_prob
        self.high_prob = high_quantise_prob
        self.per_graph = per_graph

    def _process_graph(self, graph):
        if graph.num_nodes == 0 or graph.edge_index.size(1) == 0:  # Check for empty graphs
            graph.prob_mask = torch.zeros(graph.num_nodes, dtype=torch.float)
            return graph

        n = graph.num_nodes
        indegree = degree(graph.edge_index[1], n, dtype=torch.long)
        counts = torch.bincount(indegree)

        step_size = (self.high_prob - self.low_prob) / n
        indegree_ps = counts * step_size
        indegree_ps = torch.cumsum(indegree_ps, dim=0)
        indegree_ps += self.low_prob
        graph.prob_mask = indegree_ps[indegree]

        return graph

    def __call__(self, data):
        if self.per_graph and isinstance(data, Batch):
            graphs = data.to_data_list()
            processed = []
            for g in graphs:
                g = self._process_graph(g)
                processed.append(g)
            return Batch.from_data_list(processed)
        else:
            return self._process_graph(data)


In [15]:
from torch_geometric.transforms import Compose
import torch_geometric.transforms as T

def get_dataset(path, name, sparse=True, cleaned=False, DQ=None):
    # Load the dataset
    molecule_net_dataset_names = {name.lower(): name for name in MoleculeNet.names.keys()}
    
    # Load the dataset with the composed transform
    dataset = MoleculeNet(
        root=path,
        name=molecule_net_dataset_names[name.lower()],
        #transform=transform
    )

   # Remove graphs with zero nodes
    filtered_data = []
    filtered_labels = []
    for data in dataset:

        if data.edge_index.numpy().size > 0 and data.num_nodes > 0:  # Keep only graphs with at least one node
            filtered_data.append(data)
            filtered_labels.append(data.y)
    
    # Replace the dataset with the filtered data
    dataset._data_list = filtered_data
    dataset._indices = range(len(filtered_data))  # Update indices
    dataset.data.y = torch.stack(filtered_labels)  # Update labels
    
    print(dataset.num_classes)
    if dataset.data.x is None:
        max_degree = 0
        degs = []
        for data in dataset:
            degs += [degree(data.edge_index[0], dtype=torch.long)]
            max_degree = max(max_degree, degs[-1].max().item())

        if max_degree < 1000:
            dataset.transform = T.OneHotDegree(max_degree)
        else:
            deg = torch.cat(degs, dim=0).to(torch.float)
            mean, std = deg.mean().item(), deg.std().item()
            dataset.transform = NormalizedDegree(mean, std)

    if not sparse:
        num_nodes = max_num_nodes = 0
        for data in dataset:
            num_nodes += data.num_nodes
            max_num_nodes = max(data.num_nodes, max_num_nodes)

        # Filter out a few really large graphs in order to apply DiffPool.
        if name == "BBBP":
            num_nodes = min(int(num_nodes / len(dataset) * 1.5), max_num_nodes)
        else:
            num_nodes = min(int(num_nodes / len(dataset) * 5), max_num_nodes)

        indices = []
        for i, data in enumerate(dataset):
            if data.num_nodes <= num_nodes:
                indices.append(i)
        dataset = dataset[torch.tensor(indices)]

        if dataset.transform is None:
            dataset.transform = T.ToDense(num_nodes)
        else:
            dataset.transform = T.Compose([dataset.transform, T.ToDense(num_nodes)])
    
 
    
    # If there are existing transforms, add them
    if DQ is not None:
        print(f"Generating ProbabilisticHighDegreeMask: {DQ}")
        dq_transform = ProbabilisticHighDegreeMask(
            DQ["prob_mask_low"], min(DQ["prob_mask_low"] + DQ["prob_mask_change"], 1.0)
        )

    if dataset.transform is None:
        dataset.transform = dq_transform
    else:
        dataset.transform = T.Compose([dataset.transform, dq_transform])

    return dataset      
    #if dataset.transform is None:
        #print(f"dataset.transform:{dataset.transform}")
  
        #dataset.transform = T.Compose([dataset.transform, dq_transform])
    

    return dataset

In [16]:
dataset_name='BBBP'
dataset = get_dataset(args.path, dataset_name, sparse=True, DQ=DQ)
print(dataset.num_classes)



2
Generating ProbabilisticHighDegreeMask: {'prob_mask_low': 0.0, 'prob_mask_change': 0.1}
2


###  Output dir and tensorboard writer

In [17]:
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
from pathlib import Path


def append_date_and_time_to_string(string):
    now = datetime.utcnow().strftime("%m_%d_%H_%M_%S")

    return Path(string) / now


def set_outputdir_and_writer(
    model_name,
    outdir,
    num_layers,
    hidden,
    lr,
    quant_mode,
    ste,
    momentum,
    percentile,
    is_DQ,
    w_decay,
    low,
    change,
):

    layers = "layers_" + str(num_layers)
    hidden = "hidden_" + str(hidden)

    ste_config = "STE_" if ste else "GC_"
    if momentum:
        ste_config += "MOM"
    elif percentile is not None:
        ste_config += "PER"
    else:
        ste_config += "ABS"

    if is_DQ:
        quant_mode += "_DQ_low" + str(low) + "_chng" + str(change)

    dir = (
        Path(outdir)
        / model_name)

    dir = append_date_and_time_to_string(dir)

    writer = SummaryWriter(dir)
    print(f"Output dir:{dir}")

    return dir, writer


In [18]:
###############################################################
#model = args.model
#dataset_name = args.dataset_name
#num_layers = args.num_layers
#hidden_units=args.hidden_units
#bit=args.bit
#max_epoch = args.max_epoch
#resume = args.resume
path2result = args.result_folder+'/'+'_'+dataset_name
path2check = args.check_folder+'/'+args.model+'_'+dataset_name
if not os.path.exists(path2result):
    os.makedirs(path2result)
if not os.path.exists(path2check):
    os.makedirs(path2check)

In [19]:

# output dir and tensorboard writer
dir, writer = utils.set_outputdir_and_writer(
    "GIN",
    args.outdir,
    args.num_layers,
    args.hidden,
    args.lr,
    qypte,
    ste,
    momentum,
    percentile,
    args.DQ,
    args.wd,
    args.low,
    args.change,
)

Output dir:D:\output\BBBPBINexps\INT8-DQ\GIN\layers_5\hidden_64\INT4_DQ_low0.0_chng0.1\STE_PER\lr_0.01\wd_4e-05\06_22_11_47_38


## qGIN Model with Quatization

In [20]:

def create_quantizer(qypte, ste, momentum, percentile, signed, sample_prop):
    if qypte == "FP32":
        return Identity
    else:
        return lambda: IntegerQuantizer(
            4 if qypte == "INT4" else 8,
            signed=signed,
            use_ste=ste,
            use_momentum=momentum,
            percentile=percentile,
            sample=sample_prop,
        )


def make_quantizers(qypte, dq, sign_input, ste, momentum, percentile, sample_prop):
    if dq:
        # GIN doesn't apply DQ to the LinearQuantize layers so we keep the 
        # default inputs, weights, features keys.
        # See NOTE in the multi_quant.py file
        layer_quantizers = {
            "inputs": create_quantizer(
                qypte, ste, momentum, percentile, sign_input, sample_prop
            ),
            "weights": create_quantizer(
                qypte, ste, momentum, percentile, True, sample_prop
            ),
            "features": create_quantizer(
                qypte, ste, momentum, percentile, True, sample_prop
            ),
        }
        mp_quantizers = {
            "message_low": create_quantizer(
                qypte, ste, momentum, percentile, True, sample_prop
            ),
            "message_high": create_quantizer(
                "FP32", ste, momentum, percentile, True, sample_prop
            ),
            "update_low": create_quantizer(
                qypte, ste, momentum, percentile, True, sample_prop
            ),
            "update_high": create_quantizer(
                "FP32", ste, momentum, percentile, True, sample_prop
            ),
            "aggregate_low": create_quantizer(
                qypte, ste, momentum, percentile, True, sample_prop
            ),
            "aggregate_high": create_quantizer(
                "FP32", ste, momentum, percentile, True, sample_prop
            ),
        }
    else:
        layer_quantizers = {
            "inputs": create_quantizer(
                qypte, ste, momentum, percentile, sign_input, sample_prop
            ),
            "weights": create_quantizer(
                qypte, ste, momentum, percentile, True, sample_prop
            ),
            "features": create_quantizer(
                qypte, ste, momentum, percentile, True, sample_prop
            ),
        }
        mp_quantizers = {
            "message": create_quantizer(
                qypte, ste, momentum, percentile, True, sample_prop
            ),
            "update_q": create_quantizer(
                qypte, ste, momentum, percentile, True, sample_prop
            ),
            "aggregate": create_quantizer(
                qypte, ste, momentum, percentile, True, sample_prop
            ),
        }
    return layer_quantizers, mp_quantizers


class ResettableSequential(Sequential):
    def reset_parameters(self):
        for child in self.children():
            if hasattr(child, "reset_parameters"):
                child.reset_parameters()


class GIN(torch.nn.Module):
    def __init__(
        self,
        dataset,
        num_layers,
        hidden,
        dq,
        qypte,
        ste,
        momentum,
        percentile,
        sample_prop,
    ):
        super(GIN, self).__init__()

        self.is_dq = dq
        gin_layer = GINConvMultiQuant if dq else GINConvQuant 

        lq, mq = make_quantizers(
            qypte,
            dq,
            False,
            ste=ste,
            momentum=momentum,
            percentile=percentile,
            sample_prop=sample_prop,
        )
        lq_signed, _ = make_quantizers(
            qypte,
            dq,
            True,
            ste=ste,
            momentum=momentum,
            percentile=percentile,
            sample_prop=sample_prop,
        )

        # NOTE: see comment in multi_quant.py on the use of 
        # "mask-aware" MLPs.
        self.conv1 = gin_layer(
            ResettableSequential(
                Linear(dataset.num_features, hidden),
                ReLU(),
                LinearQuantized(hidden, hidden, layer_quantizers=lq),
                ReLU(),
                BN(hidden),
            ),
            train_eps=True,
            mp_quantizers=mq,
        )
        self.convs = torch.nn.ModuleList()
        for i in range(num_layers - 1):
            self.convs.append(
                gin_layer(
                    ResettableSequential(
                        LinearQuantized(hidden, hidden, layer_quantizers=lq_signed),
                        ReLU(),
                        LinearQuantized(hidden, hidden, layer_quantizers=lq),
                        ReLU(),
                        BN(hidden),
                    ),
                    train_eps=True,
                    mp_quantizers=mq,
                )
            )

        self.lin1 = LinearQuantized(hidden, hidden, layer_quantizers=lq_signed)
        self.lin2 = LinearQuantized(hidden, dataset.num_classes, layer_quantizers=lq)
        

    def reset_parameters(self):
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        # NOTE: It is possible to use the same mask consistently or generate a 
        # new mask per layer. For other experiments we used a per-layer mask
        # We did not observe major differences but we expect the impact will
        # be layer and dataset dependent. Extensive experiments assessing the
        # difference were not run, however, due to the high cost.
         
        #if hasattr(data, "prob_mask") and data.prob_mask is not None:
            #mask = evaluate_prob_mask(data)
        #else:
            #mask = None
            
            
                
         
        mask = evaluate_prob_mask(data)
        #mask = mask.float()  # Convert to FloatTensor if needed
        #print(mask)
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = x.float()  # Convert to FloatTensor

        x = self.conv1(x, edge_index, mask)
        for conv in self.convs:
            x = conv(x, edge_index, mask)

        x = global_mean_pool(x, batch)
        # NOTE: the linear layers from here do not contribute significantly to run-time
        # Therefore you probably don't want to quantize these as it will likely have 
        # an impact on performance.
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        # NOTE: This is a quantized final layer. You probably don't want to be
        # this aggressive in practice.
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)

    
model = GIN(
    dataset,
    num_layers=args.num_layers,
    hidden=args.hidden,
    dq=args.DQ,
    qypte=qypte,
    ste=ste,
    momentum=momentum,
    percentile=None,
    sample_prop=args.sample_prop,
)
    

# Helpful Function

In [21]:
def k_fold(dataset, folds):
    skf = StratifiedKFold(folds, shuffle=True, random_state=12345)
      # Aggregate labels across tasks (e.g., sum or mean)
    y = dataset.data.y.sum(dim=1)  # Sum across tasks
    # OR select a single task
    y = dataset.data.y[:, 0]  # Use the first task for stratification
    #print(dataset.data.y.shape)

    test_indices, train_indices = [], []
    for _, idx in skf.split(torch.zeros(len(dataset)), y):
        test_indices.append(torch.from_numpy(idx))

    val_indices = [test_indices[i - 1] for i in range(folds)]

    for i in range(folds):
        train_mask = torch.ones(len(dataset), dtype=torch.bool)
        train_mask[test_indices[i]] = 0
        train_mask[val_indices[i]] = 0
        train_indices.append(train_mask.nonzero().view(-1))

    return train_indices, test_indices, val_indices

def num_graphs(data):
    if data.batch is not None:
        return data.num_graphs
    else:
        return data.x.size(0)

### Functions for Mmeasuring criterias

In [22]:
def get_num_parameters(model: nn.Module, count_nonzero_only=False) -> int:
    """
    calculate the total number of parameters of model
    :param count_nonzero_only: only count nonzero weights
    """
    num_counted_elements = 0
    for param in model.parameters():
        if count_nonzero_only:
            num_counted_elements += param.count_nonzero()
        else:
            num_counted_elements += param.numel()
    return num_counted_elements

# Function to get CPU usage
def get_cpu_usage():
    return psutil.cpu_percent(interval=1)



# Function to approximate power consumption (Assume some average power usage per CPU percentage point)
def estimate_power_usage(cpu_usage):
    base_power_usage = 10  # Assumed base power usage in watts
    power_per_percent = 0.5  # Assumed additional watts per CPU usage percent
    return base_power_usage + (power_per_percent * cpu_usage)

# The model size based on the number of parameters
import os
import torch
import torch.nn as nn

def calculate_model_size(model: nn.Module, 
                         qypte: str = 'fp32', 
                         include_metadata: bool = False,
                         model_path: str = None) -> float:
    """
    Calculate model size in KB/MB for different precisions.
    
    Args:
        model: PyTorch model
        precision: 'fp32' (32-bit float) or 'int4' (4-bit integer)
        include_metadata: Whether to include PyTorch metadata in size calculation
        model_path: If provided, will check actual file size on disk
        
    Returns:
        Size in KB (if include_metadata=False) or actual file size (if include_metadata=True)
    """
    # Get total number of parameters
    total_params = sum(p.numel() for p in model.parameters())
    
    # Calculate theoretical size
    if qypte == 'FP32':
        size_bits = total_params * 32
    elif qypte== 'INT4':
        size_bits = total_params * 4
    elif qypte == 'INT8':
        size_bits = total_params * 8    
   
    
    size_bytes = size_bits / 8
    size_kb = size_bytes / 1024
    
    # If checking actual file size
    if include_metadata and model_path:
        if not os.path.exists(model_path):
            # Save model to temporary file if path doesn't exist
            torch.save(model.state_dict(), model_path)
        actual_size_kb = os.path.getsize(model_path) / 1024
        return actual_size_kb
    
    return size_kb




In [23]:
## True
def train(model, optimizer, loader):
    model.train()
    total_loss = 0
    for data in loader:
        optimizer.zero_grad()
        data = data.to(device)
        out = model(data)
        target = data.y.view(-1).long()  # Ensure target is LongTensor

        # Check if the last batch is smaller
        if out.size(0) != target.size(0):
            target = target[:out.size(0)]  # Truncate target to match output size

        loss = F.nll_loss(out, target)  # Use the converted target
        loss.backward()
        total_loss += loss.item() * num_graphs(data)
        optimizer.step()
    return total_loss / len(loader.dataset)



def eval_loss(model, loader):
    model.eval()

    loss = 0
    for data in loader:
        data = data.to(device)
        with torch.no_grad():
            out = model(data)
            # Convert data.y to LongTensor
            target = data.y.view(-1).long()  # Ensure target is LongTensor
            loss += F.nll_loss(out, target, reduction="sum").item()
    return loss / len(loader.dataset)


def eval_acc(model, loader):
    model.eval()

    correct = 0
    for data in loader:
        data = data.to(device)
        with torch.no_grad():
            pred = model(data).max(1)[1]
            # Convert data.y to LongTensor
            target = data.y.view(-1).long()  # Ensure target is LongTensor
            correct += pred.eq(target).sum().item()
    return correct / len(loader.dataset)

## Training Process

In [24]:

def cross_validation_with_val_set(
    dataset,
    model,
    folds,
    epochs,
    batch_size,
    lr,
    lr_decay_factor,
    lr_decay_step_size,
    weight_decay,
    qypte='FP32',
    use_tqdm=True,
    writer=None,
    logger=None,
):

        
        val_losses, accs, durations = [], [], []
        quant_model_accuracy=[]
        quant_model_loss=[]
        t_quant_model=[]
        Num_parm_quant_model=[]
        quant_model_size=[]
        quant_energy_consumption=[]
        quant_cpu_usage=[]
        quant_memory_usage=[]
        max_acc=0.4
        
       
        # Initialize a dictionary to store all results per iteration
        Eva_iter = {
            "val losses per iter": [],
            "durations per iter": [],
            "quant model accuracy per iter": [],
            "time inference of quant model per iter": [],
            "number parmameters of quant model per iter": [],  # Store the best accuracy for each fold
            "size of quant model per iter": [],
            "energy consumption of quant model per iter": [],
            "cpu usage of quant model per iter": [],
            "total memory usage of quant model per iter": [],
            "final_metrics": {}  # Store final metrics (mean, std, etc.)
        }
        
        for fold, (train_idx, test_idx, val_idx) in enumerate(zip(*k_fold(dataset, folds))):
            train_dataset = dataset[train_idx.tolist()]
            test_dataset = dataset[test_idx.tolist()]
            val_dataset = dataset[val_idx.tolist()]
            if "adj" in train_dataset[0]:
                    train_loader = DenseLoader(train_dataset, batch_size, shuffle=True)
                    val_loader = DenseLoader(val_dataset, batch_size, shuffle=False)
                    test_loader = DenseLoader(test_dataset, batch_size, shuffle=False)
            else:
                    train_loader = DataLoader(train_dataset, batch_size, shuffle=True)
                    val_loader = DataLoader(val_dataset, batch_size, shuffle=False)
                    test_loader = DataLoader(test_dataset, batch_size, shuffle=False)

            model.to(device).reset_parameters()
            optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

            if torch.cuda.is_available():
                torch.cuda.synchronize()

            t_start = time.perf_counter()

            #if use_tqdm:
                #t = tqdm(total=epochs, desc="Fold #" + str(fold))
            Eva_fold= OrderedDict() #It is a dictionary to arrange output of this fold
            for epoch in range(1, epochs + 1):
                    train_loss = train(model, optimizer, train_loader)
                    val_loss = eval_loss(model, val_loader)
                    val_losses.append(val_loss)
                    
                    accs.append(eval_acc(model, test_loader))
                    eval_info = {
                        "fold": fold,
                        "epoch": epoch,
                        "train_loss": train_loss,
                        "val_loss": val_losses[-1],
                        "test_acc": accs[-1],
                    }
                    acc_test=accs[-1]  
                
                    if logger is not None:
                        logger(eval_info)

                    if writer is not None:
                        writer.add_scalar(f"Fold{fold}/Train_Loss", train_loss, epoch)
                        writer.add_scalar(f"Fold{fold}/Val_Loss", val_loss, epoch)
                        writer.add_scalar(
                           f"Fold{fold}/Lr", optimizer.param_groups[0]["lr"], epoch
                        )

                    if epoch % lr_decay_step_size == 0:
                        for param_group in optimizer.param_groups:
                            param_group["lr"] = lr_decay_factor * param_group["lr"]

                    if epoch % 30 == 0:
                        print(f"Eval Epoch: {epoch} |Val_loss:{val_loss:.03f}| Train_Loss: {train_loss:.3f} | Acc_Val: {val_losses[-1]:.3f}|Fold: {fold}")
                   

                
                    if(acc_test>max_acc):
                        path =  path2check+'/'+args.model+'_'+dataset_name+'_'+'quantized.pth.tar'
                        #path = dir+'/'+args.model+'_'+dataset_name+'_'+str(bit)+'bit'+'quantized.pth.tar'
                        max_acc = acc_test
                        torch.save({'state_dict': model.state_dict(), 'best_accu': acc_test}, path)
  
                    if torch.cuda.is_available():
                        torch.cuda.synchronize()
                    t_end = time.perf_counter()
                    durations.append(t_end - t_start)
                    
            # Start monitoring CPU and memory usage, model size, number of parametes, time inference and  power consumption
            
            
            quant_model_path= path2check+'/'+args.model+'_'+dataset_name+'_'+'quantized.pth.tar'
            #state = torch.load(quant_model_path)
            #dict=state['state_dict']
            #recover_model = lambda: model.load_state_dict(state['state_dict'])
            
            gc.collect()
            time.sleep(5)  # Add a 5-second delay to stabilize the initial state
            tracemalloc.start()  # Start tracking memory allocations
            snapshot_before = tracemalloc.take_snapshot()#take a snapshot of the current memory state before starting the measurement.

            t0 = time.perf_counter()
            initial_cpu_usage = get_cpu_usage()
            power_usage = estimate_power_usage(initial_cpu_usage)


            fold_quant_model_accuracy= eval_acc(model, test_loader)

            fold_quant_cpu_usage = get_cpu_usage()
            t1 = time.perf_counter()
            fold_t_quant_model=t1-t0

            snapshot_after = tracemalloc.take_snapshot()
            tracemalloc.stop()
            top_stats = snapshot_after.compare_to(snapshot_before, 'lineno')

            folde_quant_total_memory_diff = sum([stat.size_diff for stat in top_stats])
            fold_quant_energy_consumption = power_usage * fold_t_quant_model
            #fold_quant_model_size = os.path.getsize(main_model_path)
            fold_quant_model_size =calculate_model_size(model, qypte )
            fold_num_parm_quant_model=get_num_parameters(model, count_nonzero_only=True)

            gc.collect()
            time.sleep(5) 
            #Update Eva dictionary
            Eva_fold.update({'quant model accuracy per fold': fold_quant_model_accuracy,
                        'time inference of quant model per fold':fold_t_quant_model,
                        'number parmameters of quant model per fold': fold_num_parm_quant_model,
                        'size of quant model per fold': fold_quant_model_size, 
                        'energy consumption of quant model per fold':fold_quant_energy_consumption,
                        'total memory usage of quant model per fold':folde_quant_total_memory_diff,
                        'cpu usage of quant model per fold':fold_quant_cpu_usage
                       })
            
            gc.collect()
            time.sleep(5) 
   

            quant_model_accuracy.append(Eva_fold['quant model accuracy per fold'])
            t_quant_model.append(Eva_fold['time inference of quant model per fold'])
            Num_parm_quant_model.append(int(Eva_fold['number parmameters of quant model per fold']))
            quant_model_size.append(int(Eva_fold['size of quant model per fold']))
            quant_energy_consumption.append(Eva_fold['energy consumption of quant model per fold'])
            quant_cpu_usage.append(Eva_fold['cpu usage of quant model per fold'])
            quant_memory_usage.append(Eva_fold['total memory usage of quant model per fold'])

           

     
     
        Eva_iter["quant model accuracy per iter"]= stat.mean(quant_model_accuracy)
        Eva_iter["time inference of quant model per iter"]= stat.mean(t_quant_model)
        Eva_iter["number parmameters of quant model per iter"]=  stat.mean(Num_parm_quant_model)
        Eva_iter["size of quant model per iter"]= stat.mean(quant_model_size)
        Eva_iter["energy consumption of quant model per iter"]= stat.mean(quant_energy_consumption)
        Eva_iter["cpu usage of quant model per iter"]= stat.mean(quant_cpu_usage)
        Eva_iter["total memory usage of quant model per iter"]= stat.mean(quant_memory_usage)
    
    
        loss, acc, duration = tensor(val_losses), tensor(accs), tensor(durations)
        loss, acc = loss.view(folds, epochs), acc.view(folds, epochs)
        loss, argmin = loss.min(dim=1)
        acc = acc[torch.arange(folds, dtype=torch.long), argmin]

        Eva_iter["val losses per iter"]= loss.mean().item()
        Eva_iter["durations per iter"]= duration.mean().item()

        
      
        
        return Eva_iter, model

### Manual Measurement

In [25]:
# The following are all list of criteria for measurements. 
# We collect all desired datas of each list across iterations. 
# Then, we compute average and std of each list.



#quant model
Quant_val_loss=[]
Quant_duration=[]
Quant_model_accuracy=[]
T_quant_model=[]
Num_parm_quant_model=[]
Quant_model_size=[]
Quant_Energy_Consumption=[]
Quant_Cpu_Usage=[]
Quant_Memory_Usage=[]


# Here is the dictionary to record the list of all measurements
Eva_measure={'quant validation loss':Quant_val_loss,
             'quant duration':Quant_duration,
            'quant model accuracy': Quant_model_accuracy,
            'time inference of quant model':T_quant_model,
            'number parmameters of quant model':Num_parm_quant_model,
            'quant model size':Quant_model_size,
            'energy consumption of quant model':Quant_Energy_Consumption,
            'cpu usage of quant model':Quant_Cpu_Usage,
            'memory usage of quant model':Quant_Memory_Usage}


In [26]:
iterations=1
epochs=2
folds=10
batch_size=args.batch_size
lr=args.lr
lr_decay_factor=args.lr_decay_factor
lr_decay_step_size=args.lr_decay_step_size
weight_decay=args.wd
writer=writer
logger=None
use_tqdm=True

In [27]:
#### load the quantized  model

for i in range(iterations):
    print('********************************************')
    print(f'The iteration is :{i+1} ') 

  
    Eva_iter, model=cross_validation_with_val_set(
                                            dataset,
                                            model,
                                            folds,
                                            epochs,
                                            batch_size,
                                            lr,
                                            lr_decay_factor,
                                            lr_decay_step_size,
                                            weight_decay,
                                            qypte,
                                            use_tqdm=True,
                                            writer=None,
                                            logger=None,)


 
    Quant_val_loss.append(Eva_iter["val losses per iter"])
    Quant_duration.append(Eva_iter["durations per iter"])
    Quant_model_accuracy.append(Eva_iter["quant model accuracy per iter"])
    T_quant_model.append(Eva_iter["time inference of quant model per iter"])
    Num_parm_quant_model.append(Eva_iter["number parmameters of quant model per iter"])
    Quant_model_size.append(Eva_iter["size of quant model per iter"])
    Quant_Energy_Consumption.append(Eva_iter["energy consumption of quant model per iter"])
    Quant_Cpu_Usage.append( Eva_iter["cpu usage of quant model per iter"])
    Quant_Memory_Usage.append(Eva_iter["total memory usage of quant model per iter"])
 

********************************************
The iteration is :1 




In [31]:
Eva_iter

{'val losses per iter': 0.5456306338310242,
 'durations per iter': 61.01333999633789,
 'quant model accuracy per iter': 0.7128658359895682,
 'time inference of quant model per iter': 2.8828856599982826,
 'number parmameters of quant model per iter': 43015,
 'size of quant model per iter': 168,
 'energy consumption of quant model per iter': 118.64256476242474,
 'cpu usage of quant model per iter': 61.76,
 'total memory usage of quant model per iter': 48523.7,
 'final_metrics': {}}

In [21]:
# This is a dictionary to save all measurements. Aftre measuring, we can compute mean and std of each item.
from collections import OrderedDict 
Eva_final = OrderedDict()



quant_model_val_loss_mean =stat.mean(Quant_val_loss)
quant_model_val_loss_std = stat.stdev(Quant_val_loss)
#desc = "{:.3f} ± {:.3f}".format(acc_mean,acc_std)
Eva_final.update({'Ave of quant loss validation':float(format(quant_model_val_loss_mean, '.3f'))})
Eva_final.update({'Std of quant loss validation':float(format(quant_model_val_loss_std, '.3f'))})    

quant_model_duration_mean =stat.mean(Quant_duration)
quant_model_duration_std = stat.stdev(Quant_duration)
#desc = "{:.3f} ± {:.3f}".format(acc_mean,acc_std)
Eva_final.update({'Ave of quant model duration':float(format(quant_model_duration_mean , '.3f'))})
Eva_final.update({'Std of quant model duration':float(format(quant_model_duration_std, '.3f'))})                                         
                                     

quant_model_accuracy_mean =stat.mean(Quant_model_accuracy)
quant_model_accuracy_std = stat.stdev(Quant_model_accuracy)
#desc = "{:.3f} ± {:.3f}".format(acc_mean,acc_std)
Eva_final.update({'Ave of quant model accuracy':float(format(quant_model_accuracy_mean, '.3f'))})
Eva_final.update({'Std of quant model accuracy':float(format(quant_model_accuracy_std, '.3f'))})
                 

t_quant_model_mean = stat.mean(T_quant_model)
t_quant_model_std =stat.stdev(T_quant_model)
#desc = "{:.3f} ± {:.3f}".format(acc_mean,acc_std)
Eva_final.update({'Ave of time inference of quant model':float(format(t_quant_model_mean, '.3f'))})
Eva_final.update({'Std of time inference of quant model':float(format(t_quant_model_std, '.3f'))})

num_parm_quant_model_mean = stat.mean(Num_parm_quant_model)
num_parm_quant_model_std = stat.stdev(Num_parm_quant_model)
#desc = "{:.3f} ± {:.3f}".format(acc_mean,acc_std)
Eva_final.update({'Ave of number parmameters of quant model':num_parm_quant_model_mean})
Eva_final.update({'Std of number parmameters of quant model':num_parm_quant_model_std})

quant_model_size_mean =stat.mean( Quant_model_size)
quant_model_size_std = stat.stdev(Quant_model_size)
#desc = "{:.3f} ± {:.3f}".format(acc_mean,acc_std)
Eva_final.update({'Ave of quant model size':quant_model_size_mean})
Eva_final.update({'Std of quant_model_size':quant_model_size_std })

quant_energy_consumption_mean = stat.mean(Quant_Energy_Consumption)
quant_energy_consumption_std = stat.stdev(Quant_Energy_Consumption)
#desc = "{:.3f} ± {:.3f}".format(acc_mean,acc_std)
Eva_final.update({'Ave of energy consumption of quant model':quant_energy_consumption_mean })
Eva_final.update({'Std of energy consumption of quant model':quant_energy_consumption_std})


quant_cpu_usage_mean = stat.mean(Quant_Cpu_Usage)
quant_cpu_usage_std = stat.stdev(Quant_Cpu_Usage)
#desc = "{:.3f} ± {:.3f}".format(acc_mean,acc_std)
Eva_final.update({'Ave of cpu usage of quant model':quant_cpu_usage_mean})
Eva_final.update({'Std of cpu usage of quant model':quant_cpu_usage_std})

quant_memory_usage_mean = stat.mean(Quant_Memory_Usage)
quant_memory_usage_std = stat.stdev(Quant_Memory_Usage)
#desc = "{:.3f} ± {:.3f}".format(acc_mean,acc_std)
Eva_final.update({'Ave of memory usage of quant model':quant_memory_usage_mean})
Eva_final.update({'Std of memory usage of quant model':quant_memory_usage_std})

#################################


# Determing Quantization Method 
if args.DQ == True:
    dq='DQ'
else:
    dq='QAT'
print(f"All measurement about {dq} Quantization process of type:{ qypte} on modes:{args.DQ}  ")   
Eva_final

All measurement about DQ Quantization process of type:INT4 on modes:True  


OrderedDict([('Ave of quant loss validation', 0.547),
             ('Std of quant loss validation', 0.005),
             ('Ave of quant model duration', 147.96),
             ('Std of quant model duration', 40.294),
             ('Ave of quant model accuracy', 0.712),
             ('Std of quant model accuracy', 0.0),
             ('Ave of time inference of quant model', 3.218),
             ('Std of time inference of quant model', 0.436),
             ('Ave of number parmameters of quant model', 43014.95),
             ('Std of number parmameters of quant model', 0.07071067811762578),
             ('Ave of quant model size', 226429),
             ('Std of quant_model_size', 0.0),
             ('Ave of energy consumption of quant model', 166.41279508215425),
             ('Std of energy consumption of quant model', 63.67225641116794),
             ('Ave of cpu usage of quant model', 77.35),
             ('Std of cpu usage of quant model', 31.918800102760756),
             ('Ave of memo

### Recording the output

In [22]:
# Determining mode
if args.ste_abs:
    mode = 'ste_abs'
elif args.ste_mom:
    mode = 'ste_mom'   
elif args.gc_abs:
    mode = 'gc_abs'
elif args.gc_mom:
    mode = 'gc_mom'
elif args.ste_per:
    mode = 'ste_per'
elif args.gc_per:
    mode = 'gc_per'


In [23]:

qypte='INT4'
dataset_name='BBBP'

if args.DQ == True:
    dq='DQ'
else:
    dq='QAT'


file_name = path2result+'/'+'Method_type_'+ qypte +'_and_Quantization_is_'+dq+'_On_'+dataset_name+'_with_Mode_'+mode+'.txt'

with open(file_name, 'w') as f:
    for key, value in vars(args).items():
        f.write('%s:%s\n'%(key, value))

    for key, value in Eva_final.items():
        f.write('%s:%s\n'%(key, value))

    for key, value in Eva_measure.items():
        f.write('%s:%s\n' % (key, ','.join(map(str, value))))    

In [45]:
def analysis_bit(dataset, state_dict, all_positive=True):
    # Step 1: Collect layer-wise bit parameters
    layer_bits = {}
  
    for key, param in state_dict.items():
        if'quant' in key  and 'fea' in key:
            layer_name = key.split('.quant_fea')[0]
            layer_bits[layer_name] = param.abs().round() - 1

    # Step 2: Per-graph analysis
    for i, data in enumerate(dataset):
        #print(f"\n===== Analyzing Graph {i+1}/{len(dataset)} =====")
        edge_index = data.edge_index
        row, col = edge_index
        deg = degree(col, data.x.size(0)).cpu()
        
        # Step 3: Per-layer analysis within current graph
        for layer_name, bits in layer_bits.items():
            # Skip if bits tensor doesn't match current graph size
            if bits.size(0) != deg.size(0):
                continue
                
            print(f"\nLayer {layer_name}:")
            print(f"Avg bits: {bits.mean().item():.2f}")
            
            # Bit-degree correlation
            for bit_val in range(0, 9):
                mask = (bits == bit_val)
                if mask.sum() > 0:
                    avg_deg = deg[mask].mean().item()
                    print(f"  {bit_val}-bit nodes: {mask.sum().item()} nodes, Avg Degree={avg_deg:.1f}")
    
    # Step 4: Weight quantization analysis
    weight_bits = []
    for key, param in state_dict.items():
        if'quant' in key:
            print(param)
            bits = param.abs().round() - 1
            print(bits)
            weight_bits.append(bits.mean().item())
    
    print("\n===== Weight Quantization Summary =====")
    if weight_bits:
        print(f"Avg weight bits: {sum(weight_bits)/len(weight_bits):.2f}")
    else:
        print("No weight quantization parameters found")
    
    print("Analysis complete")
    return sum(weight_bits)/len(weight_bits)
   

In [46]:
## Save model by Danny bit =8
model.to(device).reset_parameters()
quant_model_path= path2check+'/'+args.model+'_'+dataset_name+'_'+'quantized.pth.tar'

state = torch.load(quant_model_path,map_location=torch.device('cpu'))
dict=state['state_dict']
analysis_bit(dataset, model.state_dict(), all_positive=True)


Layer conv1.nn.2.layer_quant.features.min_val:
Avg bits: nan

Layer conv1.nn.2.layer_quant.features.max_val:
Avg bits: nan

Layer convs.0.nn.0.layer_quant.features.min_val:
Avg bits: nan

Layer convs.0.nn.0.layer_quant.features.max_val:
Avg bits: nan

Layer convs.0.nn.2.layer_quant.features.min_val:
Avg bits: nan

Layer convs.0.nn.2.layer_quant.features.max_val:
Avg bits: nan

Layer convs.1.nn.0.layer_quant.features.min_val:
Avg bits: nan

Layer convs.1.nn.0.layer_quant.features.max_val:
Avg bits: nan

Layer convs.1.nn.2.layer_quant.features.min_val:
Avg bits: nan

Layer convs.1.nn.2.layer_quant.features.max_val:
Avg bits: nan

Layer convs.2.nn.0.layer_quant.features.min_val:
Avg bits: nan

Layer convs.2.nn.0.layer_quant.features.max_val:
Avg bits: nan

Layer convs.2.nn.2.layer_quant.features.min_val:
Avg bits: nan

Layer convs.2.nn.2.layer_quant.features.max_val:
Avg bits: nan

Layer convs.3.nn.0.layer_quant.features.min_val:
Avg bits: nan

Layer convs.3.nn.0.layer_quant.features.max

Avg bits: nan

Layer convs.2.nn.0.layer_quant.features.min_val:
Avg bits: nan

Layer convs.2.nn.0.layer_quant.features.max_val:
Avg bits: nan

Layer convs.2.nn.2.layer_quant.features.min_val:
Avg bits: nan

Layer convs.2.nn.2.layer_quant.features.max_val:
Avg bits: nan

Layer convs.3.nn.0.layer_quant.features.min_val:
Avg bits: nan

Layer convs.3.nn.0.layer_quant.features.max_val:
Avg bits: nan

Layer convs.3.nn.2.layer_quant.features.min_val:
Avg bits: nan

Layer convs.3.nn.2.layer_quant.features.max_val:
Avg bits: nan

Layer lin1.layer_quant.features.min_val:
Avg bits: nan

Layer lin1.layer_quant.features.max_val:
Avg bits: nan

Layer lin2.layer_quant.features.min_val:
Avg bits: nan

Layer lin2.layer_quant.features.max_val:
Avg bits: nan

Layer conv1.nn.2.layer_quant.features.min_val:
Avg bits: nan

Layer conv1.nn.2.layer_quant.features.max_val:
Avg bits: nan

Layer convs.0.nn.0.layer_quant.features.min_val:
Avg bits: nan

Layer convs.0.nn.0.layer_quant.features.max_val:
Avg bits: na

nan

In [79]:

edge_index = dataset.edge_index
row, col = edge_index
deg = degree(col, dataset.x.size(0)).cpu()
weight_bits=[]
for key, param in dict.items():
    
        if'quant' in key and 'fea' in key:
            print(key)
            bit = param.abs().round() - 1
            #print(bits)
            #if bit.mean().item()>0:
            weight_bits.append(bit.mean().item())
            print('The average bits of current layer:',bit.mean())
            for i in range(32): 
                if (bit==i).sum().item()!=0:
                     print(f"{i}bit:{format((bit==i).sum())} ")
            
            print('\n')
            print('The average degree of the nodes using corresponding bitwidth:')
            for i in range(32):
                index_bit = torch.where(bit==i)[0]
                #print(index_bit )
         
                if  index_bit.nelement() !=0:
                    print(f"{i}bit_deg_mean:{deg[index_bit].mean()}")
            
            print('\n')

print("Analysis complete")
   
      
sum(weight_bits)/len(weight_bits)    

conv1.nn.2.layer_quant.features.min_val
The average bits of current layer: tensor(21.)
21bit:1 


The average degree of the nodes using corresponding bitwidth:
21bit_deg_mean:0.0


conv1.nn.2.layer_quant.features.max_val
The average bits of current layer: tensor(32.)


The average degree of the nodes using corresponding bitwidth:


convs.0.nn.0.layer_quant.features.min_val
The average bits of current layer: tensor(16.)
16bit:1 


The average degree of the nodes using corresponding bitwidth:
16bit_deg_mean:0.0


convs.0.nn.0.layer_quant.features.max_val
The average bits of current layer: tensor(17.)
17bit:1 


The average degree of the nodes using corresponding bitwidth:
17bit_deg_mean:0.0


convs.0.nn.2.layer_quant.features.min_val
The average bits of current layer: tensor(11.)
11bit:1 


The average degree of the nodes using corresponding bitwidth:
11bit_deg_mean:0.0


convs.0.nn.2.layer_quant.features.max_val
The average bits of current layer: tensor(13.)
13bit:1 


The average degre

47.0

In [80]:
for key, param in dict.items():
    print(key)

conv1.eps
conv1.nn.0.weight
conv1.nn.0.bias
conv1.nn.2.weight
conv1.nn.2.bias
conv1.nn.2.layer_quant.inputs.min_val
conv1.nn.2.layer_quant.inputs.max_val
conv1.nn.2.layer_quant.features.min_val
conv1.nn.2.layer_quant.features.max_val
conv1.nn.2.layer_quant.weights.min_val
conv1.nn.2.layer_quant.weights.max_val
conv1.nn.4.weight
conv1.nn.4.bias
conv1.nn.4.running_mean
conv1.nn.4.running_var
conv1.nn.4.num_batches_tracked
conv1.mp_quantizers.message_low.min_val
conv1.mp_quantizers.message_low.max_val
conv1.mp_quantizers.update_low.min_val
conv1.mp_quantizers.update_low.max_val
conv1.mp_quantizers.aggregate_low.min_val
conv1.mp_quantizers.aggregate_low.max_val
convs.0.eps
convs.0.nn.0.weight
convs.0.nn.0.bias
convs.0.nn.0.layer_quant.inputs.min_val
convs.0.nn.0.layer_quant.inputs.max_val
convs.0.nn.0.layer_quant.features.min_val
convs.0.nn.0.layer_quant.features.max_val
convs.0.nn.0.layer_quant.weights.min_val
convs.0.nn.0.layer_quant.weights.max_val
convs.0.nn.2.weight
convs.0.nn.2.bias


In [17]:
def get_quantized_size(model):
    total_bits = 0
    
    # Helper to get bitwidth from IntegerQuantizer
    def get_bitwidth(quantizer):
        if hasattr(quantizer, 'bitwidth'):
            return quantizer.bitwidth
        # Default to 8 bits if bitwidth isn't explicitly set
        return 8  
    
    for name, module in model.named_modules():
        # Count quantized linear layers
        if isinstance(module, LinearQuantized):
            # Weights
            if hasattr(module, 'layer_quant') and 'weights' in module.layer_quant:
                quantizer = module.layer_quant['weights']
                bitwidth =4 #get_bitwidth(quantizer)
                total_bits += module.weight.numel() * bitwidth
            
            # Biases (typically remain FP32 - 32 bits)
            if module.bias is not None:
                total_bits += module.bias.numel() * 32
        
        # Count BatchNorm parameters (typically FP32)
        elif isinstance(module, nn.BatchNorm1d):
            total_bits += module.weight.numel() * 32  # gamma
            total_bits += module.bias.numel() * 32    # beta
            # Running stats (not parameters but still stored)
            total_bits += module.running_mean.numel() * 32
            total_bits += module.running_var.numel() * 32
    
    # Convert to bytes
    return total_bits / 8