$A^2Q$ Quantization method of Proteins dataset Training by GIN Models
------------

## Packages and Libraries

In [1]:
import os
import numpy as np
import random
import statistics as stat
from tabulate import tabulate
from sklearn.model_selection import StratifiedKFold
from tqdm import tqdm
import time
import argparse
from collections import OrderedDict
import os
import glob
import json
import torch
import pickle
import numpy as np
import os.path as osp

# CPU and Enegusage 
import psutil
import itertools
import tracemalloc
import gc


import torch
import torch.nn as nn
from torch import tensor
from torch.nn import Linear, Sequential, ReLU, Identity, BatchNorm1d as BN
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree,remove_self_loops
from torch_scatter import scatter_mean
from torch_geometric.data import Data
from torch.autograd.function import InplaceFunction
from torch_geometric.nn import GCNConv,GINConv,global_mean_pool,TopKPooling

# For downloading Proteins dataset from TUDataset and Transformation
from torch_geometric.datasets import TUDataset,Planetoid,GNNBenchmarkDataset
import torch_geometric.transforms as T
from torch_geometric.loader import DataLoader


# AAQ Quantization
from quantize_function.u_quant_gc_bit_debug import *
from quantize_function.MessagePassing_gc_bit import GINConvMultiQuant
from quantize_function.get_scale_index import get_deg_index, get_scale_index
from utils.quant_utils import analysis_bit


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
# For downloading BBBP dataset from MoleculeNet and Transformation
from torch_geometric.datasets import MoleculeNet
from torch_geometric.utils import dense_to_sparse
from torch.utils.data import random_split, Subset
from torch_geometric.data import Data, InMemoryDataset, DataLoader
import torch_geometric.transforms as T

### Functions for Mmeasuring criterias

In [3]:
def get_num_parameters(model: nn.Module, count_nonzero_only=False) -> int:
    """
    calculate the total number of parameters of model
    :param count_nonzero_only: only count nonzero weights
    """
    num_counted_elements = 0
    for param in model.parameters():
        if count_nonzero_only:
            num_counted_elements += param.count_nonzero()
        else:
            num_counted_elements += param.numel()
    return num_counted_elements

# Function to get CPU usage
def get_cpu_usage():
    return psutil.cpu_percent(interval=1)



# Function to approximate power consumption (Assume some average power usage per CPU percentage point)
def estimate_power_usage(cpu_usage):
    base_power_usage = 10  # Assumed base power usage in watts
    power_per_percent = 0.5  # Assumed additional watts per CPU usage percent
    return base_power_usage + (power_per_percent * cpu_usage)

# The model size based on the number of parameters
def calculate_model_size(model: nn.Module, 
                         qypte: str = 'fp32', 
                         include_metadata: bool = False,
                         model_path: str = None) -> float:
    """
    Calculate model size in KB/MB for different precisions.
    
    Args:
        model: PyTorch model
        precision: 'fp32' (32-bit float) or 'int4' (4-bit integer)
        include_metadata: Whether to include PyTorch metadata in size calculation
        model_path: If provided, will check actual file size on disk
        
    Returns:
        Size in KB (if include_metadata=False) or actual file size (if include_metadata=True)
    """
    # Get total number of parameters
    total_params = sum(p.numel() for p in model.parameters())
    
    # Calculate theoretical size
    if qypte == 'FP32':
        size_bits = total_params * 32
    elif qypte== 'INT4':
        size_bits = total_params * 4
    elif qypte == 'INT8':
        size_bits = total_params * 8    
   
    
    size_bytes = size_bits / 8
    size_kb = size_bytes / 1024
    
    # If checking actual file size
    if include_metadata and model_path:
        if not os.path.exists(model_path):
            # Save model to temporary file if path doesn't exist
            torch.save(model.state_dict(), model_path)
        actual_size_kb = os.path.getsize(model_path) / 1024
        return actual_size_kb
    
    return size_kb


## Function for Quatization

In [4]:
def paras_group(model):
    all_params = model.parameters()
    weight_paras=[]
    quant_paras_bit_weight = []
    quant_paras_bit_fea = []
    quant_paras_scale_weight = []
    quant_paras_scale_fea = []
    quant_paras_scale_xw = []
    quant_paras_bit_xw = []
    other_paras = []
    for name,para in model.named_parameters():        
        if('quant' in name and 'bit' in name and 'weight' in name):
            quant_paras_bit_weight+=[para]
            # para.requires_grad = False
        elif('quant' in name and 'bit' in name and 'fea' in name):
            quant_paras_bit_fea+=[para]
        elif('quant' in name and 'bit' not in name and 'weight' in name):
            quant_paras_scale_weight+=[para]
            # para.requires_grad = False
        elif('quant' in name and 'bit' not in name and 'fea' in name):
            quant_paras_scale_fea+=[para]
        elif('xw'in name and 'q' in name and 'bit' not in name):
            quant_paras_scale_xw+=[para]
        elif('xw'in name and 'q' in name and 'bit' in name):
            quant_paras_bit_xw+=[para]
        elif('weight' in name and 'quant' not in name ):
            weight_paras+=[para]
    params_id = list(map(id,quant_paras_bit_fea))+list(map(id,quant_paras_bit_weight))+list(map(id,quant_paras_scale_weight))+list(map(id,quant_paras_scale_fea))+list(map(id,weight_paras))\
    +list(map(id,quant_paras_scale_xw))+list(map(id,quant_paras_bit_xw))
    other_paras = list(filter(lambda p: id(p) not in params_id, all_params))
    return weight_paras,quant_paras_bit_weight,quant_paras_bit_fea,quant_paras_scale_weight,quant_paras_scale_fea,quant_paras_scale_xw,quant_paras_bit_xw,other_paras

def setup_seed(seed):
      torch.manual_seed(seed)
      torch.cuda.manual_seed_all(seed)
      np.random.seed(seed)
      random.seed(seed)
    #  torch.backends.cudnn.deterministic = True

def parameter_stastic(model,dataset,hidden_units):
    w_Byte = 0
    a_Byte = 0
    for name, par in model.named_parameters():
        if(('bit' in name)&('weight' in name)):
            if('conv1' in name):
                scale = dataset.num_node_features
            else:
                scale = hidden_units
            par = torch.floor(par)
            w_Byte = scale*par.sum()/8./1024.+w_Byte
        elif(('bit' in name)&('fea' in name)):
            if('conv1' in name):
                a_scale = 0
            else:
                a_scale = hidden_units
            # a_scale = dataset.data.num_nodes
            par = torch.floor(par)
            a_Byte = a_scale*par.sum()/8./1024.+a_Byte
    return w_Byte, a_Byte

class ResettableSequential(nn.Sequential):
    def reset_parameters(self):
        for child in self.children():
            if hasattr(child, "reset_parameters"):
                child.reset_parameters()
    def forward(self,input,edge_index,bit_sum):
        for model in self:
            input,_,bit_sum = model(input,edge_index,bit_sum)
        return input,bit_sum



In [5]:
# Relu and Batch Normalization
class relu(nn.Module):
    def __init__(self,):
        super().__init__()
    def forward(self,x,edge_index,bit_sum):
        x[x<0] = 0
        return x,edge_index,bit_sum

class bn(nn.Module):
    def __init__(self,hidden_units):
        super().__init__()
        self.bn = nn.BatchNorm1d(hidden_units)
    def forward(self,x,edge_index,bit_sum):
        x = self.bn(x)
        return x,edge_index,bit_sum

## qGIN Model with Quatization

In [22]:

class qGIN(nn.Module):
    def __init__(self, dataset, num_layers, hidden_units, bit, num_deg=1000, is_q=True,
                    uniform=False,init='norm'):
        super(qGIN, self).__init__()
        gin_layer = GINConvMultiQuant
        self.bit = bit
        para_list=[[{'alpha_init':0.01,'gama_init':0.01,'alpha_std':0.1,'gama_std':0.1},{'alpha_init':0.01,'gama_init':0.01,'alpha_std':0.1,'gama_std':0.1},{'gama_init':0.70,'gama_std':0.1}],
                   [{'alpha_init':0.01,'gama_init':0.01,'alpha_std':0.1,'gama_std':0.1},{'alpha_init':0.01,'gama_init':0.01,'alpha_std':0.1,'gama_std':0.1},{'gama_init':0.6,'gama_std':0.7}],
                   [{'alpha_init':0.01,'gama_init':0.01,'alpha_std':0.1,'gama_std':0.1},{'alpha_init':0.01,'gama_init':0.01,'alpha_std':0.1,'gama_std':0.1},{'gama_init':0.76,'gama_std':0.68}],
                   [{'alpha_init':0.01,'gama_init':0.01,'alpha_std':0.1,'gama_std':0.1},{'alpha_init':0.01,'gama_init':0.01,'alpha_std':0.1,'gama_std':0.1},{'gama_init':0.6,'gama_std':0.5}],
                   [{'alpha_init':0.01,'gama_init':0.01,'alpha_std':0.1,'gama_std':0.1},{'alpha_init':0.01,'gama_init':0.01,'alpha_std':0.1,'gama_std':0.1},{'gama_init':0.6,'gama_std':0.3}],
                   [{'alpha_init':0.01,'gama_init':0.01,'alpha_std':0.1,'gama_std':0.1}],
                   [{'alpha_init':0.01,'gama_init':0.01,'alpha_std':0.1,'gama_std':0.1}]]
        if(is_q):
            # As the DQ, we either don't quantize the input features of the REDDIT-BINARY dataset because the feature is only 1-dimension.
            self.conv1 = gin_layer(
                ResettableSequential(
                    QLinear(dataset.num_features,hidden_units, num_deg, bit,para_dict=para_list[0][0], all_positive=True,
                            quant_fea=True,
                            uniform=uniform,init=init),
                    relu(),
                    QLinear(hidden_units, hidden_units, num_deg, bit, para_dict=para_list[0][1],all_positive=True,
                            uniform=uniform,init=init),
                    relu(),
                ),
                train_eps=True,
                in_features=num_deg, out_features=1,
                bit=bit, para_dict=para_list[0][2],quant_fea=True,uniform=uniform
            )
        else:
            self.conv1 = GINConv(
                nn.Sequential(
                    nn.Linear(dataset.num_features, hidden_units),
                    nn.ReLU(),
                    nn.Linear(hidden_units, hidden_units),
                    nn.ReLU(),
                    nn.BatchNorm1d(hidden_units),
                ),
                train_eps=True,
            )
        self.convs = torch.nn.ModuleList()
        for i in range(num_layers - 1):
            if(is_q):
                self.convs.append(
                    gin_layer(
                        ResettableSequential(
                            QLinear(hidden_units, hidden_units, num_deg,bit, para_dict=para_list[0][0],all_positive=False,
                                    uniform=uniform,init=init),
                            relu(),
                            QLinear(hidden_units, hidden_units, num_deg,bit, para_dict=para_list[0][1], all_positive=True,
                                    uniform=uniform,init=init),
                            relu(),
                        ),
                        train_eps=True,
                        in_features=num_deg, out_features=hidden_units,
                        bit=bit, para_dict=para_list[0][2], uniform=uniform,quant_fea=True
                    )
                )
            else:
                self.convs.append(
                    GINConv(
                        nn.Sequential(
                            nn.Linear(hidden_units, hidden_units),
                            nn.ReLU(),
                            nn.Linear(hidden_units, hidden_units),
                            nn.ReLU(),
                            nn.BatchNorm1d(hidden_units),
                        ),
                        train_eps=True,
                    )
                )
        self.bn_list = torch.nn.ModuleList()
        for i in range(num_layers):
            self.bn_list.append(nn.BatchNorm1d(hidden_units))
        if(is_q):
            self.lin1 = QLinear(hidden_units, hidden_units, num_deg, bit, para_dict=para_list[-1][0], all_positive=False,
                                        uniform=uniform,init=init)
            self.lin2 = QLinear(hidden_units, dataset.num_classes, num_deg, bit, para_dict=para_list[-1][0], all_positive=True,
                                        uniform=uniform,init=init)
        else:
            self.lin1 = nn.Linear(hidden_units, hidden_units)
            self.lin2 = nn.Linear(hidden_units, dataset.num_classes)

    def reset_parameters(self):
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch

        bit_sum=x.new_zeros(1)
        x,bit_sum = self.conv1(x, edge_index,bit_sum)
 
        x = self.bn_list[0](x)
        # x,_,bit_sum = self.embeding(x,edge_index,bit_sum)
        # x = F.relu(x)
        i = 1
        for conv in self.convs:
            x,bit_sum = conv(x,edge_index,bit_sum)
            x = self.bn_list[i](x)
            i=i+1
        x = global_mean_pool(x, batch)

        x,_,bit_sum = self.lin1(x,edge_index,bit_sum)
 
        x = F.relu(x)
        x,_,bit_sum = self.lin2(x,edge_index,bit_sum)
    
        return F.log_softmax(x, dim=-1),bit_sum



# Helpful Function

In [23]:
class NormalizedDegree(object):
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, data):
        deg = degree(data.edge_index[0], dtype=torch.float)
        deg = (deg - self.mean) / self.std
        data.x = deg.view(-1, 1)
        return data


def num_graphs(data):
    if data.batch is not None:
        return data.num_graphs
    else:
        return data.x.size(0)


def train(model, optimizer, loader,a_loss, a_storage=1):
    model.train()
    total_loss = 0
    for data in loader:
        optimizer.zero_grad()
        data = data.to(device)
        out,bit_sum = model(data)
        loss = F.cross_entropy(out, data.y.view(-1))
        loss_store = a_loss*F.relu(bit_sum-a_storage)**2
        loss_store.backward(retain_graph=True)
        loss.backward()
        total_loss += loss.item() * num_graphs(data)
        optimizer.step()
    return total_loss / len(loader.dataset)


def eval_acc(model, loader):
    model.eval()

    correct = 0
    for data in loader:
        data = data.to(device)
        with torch.no_grad():
            pred = model(data)[0].max(1)[1]
        correct += pred.eq(data.y.view(-1)).sum().item()
    return correct / len(loader.dataset)


def eval_loss(model, loader):
    model.eval()

    loss = 0
    for data in loader:
        data = data.to(device)
        with torch.no_grad():
            out = model(data)[0]
        loss += F.cross_entropy(out, data.y.view(-1), reduction="sum").item()
    return loss / len(loader.dataset)

def k_fold(dataset, folds):
    skf = StratifiedKFold(folds, shuffle=True, random_state=12345)

    test_indices, train_indices = [], []
    for _, idx in skf.split(torch.zeros(len(dataset)), dataset.data.y):
        test_indices.append(torch.from_numpy(idx))

    val_indices = [test_indices[i - 1] for i in range(folds)]

    for i in range(folds):
        train_mask = torch.ones(len(dataset), dtype=torch.bool)
        train_mask[test_indices[i]] = 0
        train_mask[val_indices[i]] = 0
        train_indices.append(train_mask.nonzero().view(-1))

    return train_indices, test_indices, val_indices



def load_checkpoint(model, checkpoint):
    if checkpoint != 'No':
        print("loading checkpoint...")
        model_dict = model.state_dict()
        modelCheckpoint = torch.load(checkpoint)
        pretrained_dict = modelCheckpoint['state_dict']
        new_dict = {k: v for k, v in pretrained_dict.items() if ((k in model_dict.keys()))}
        model_dict.update(new_dict)
        print('Total : {}, update: {}'.format(len(pretrained_dict), len(new_dict)))
        model.load_state_dict(model_dict)
        print("loaded finished!")
    return model


## Definition of Requirment Parameters as args

In [24]:
import sys
import argparse

# Clearing the arguments
sys.argv = ['']


parser = argparse.ArgumentParser()
parser.add_argument('--model',type=str,default='GIN')
parser.add_argument('--gpu_id',type=int,default=0)
parser.add_argument('--dataset_name',type=str,default='PROTEINS')
parser.add_argument('--num_deg',type=int,default=1000)
parser.add_argument('--num_layers', type=int, default=5)
parser.add_argument('--hidden_units',type=int,default=64)
parser.add_argument('--batch-size',type=int,default=128)
parser.add_argument('--bit',type=int,default=4)
parser.add_argument('--max_epoch',type=int,default=100)
parser.add_argument('--max_cycle',type=int,default=2000)
parser.add_argument('--folds',type=int,default=5)
parser.add_argument('--weight_decay',type=float,default=0)
parser.add_argument('--lr',type=float,default=0.01)
parser.add_argument('--a_loss',type=float,default=0.001)
parser.add_argument('--lr_quant_scale_fea',type=float,default=0.02)
parser.add_argument('--lr_quant_scale_xw',type=float,default=1e-2)
parser.add_argument('--lr_quant_scale_weight',type=float,default=0.02)
parser.add_argument('--lr_quant_bit_fea',type=float,default=0.008)
parser.add_argument('--lr_quant_bit_weight',type=float,default=0.0001)
parser.add_argument('--lr_step_size',type=int, default=50)
parser.add_argument('--lr_decay_factor',type=float,default=0.5)
parser.add_argument('--lr_schedule_patience',type=int,default=10)
parser.add_argument('--is_naive',type=bool,default=False)
###############################################################
parser.add_argument('--resume',type=bool,default=True)
parser.add_argument('--store_ckpt',type=bool,default=True)
parser.add_argument('--uniform',type=bool,default=True)
parser.add_argument('--use_norm_quant',type=bool,default=True)
###############################################################
# The target memory size of nodes features
parser.add_argument('--a_storage',type=float,default=1)
# Path to results
parser.add_argument('--result_folder',type=str,default='result')
# Path to checkpoint
parser.add_argument('--check_folder',type=str,default='checkpoint')
# Path to dataset
parser.add_argument('--pathdataset',type=str,default='/')

args = parser.parse_args()
print(args)

Namespace(model='GIN', gpu_id=0, dataset_name='PROTEINS', num_deg=1000, num_layers=5, hidden_units=64, batch_size=128, bit=4, max_epoch=100, max_cycle=2000, folds=5, weight_decay=0, lr=0.01, a_loss=0.001, lr_quant_scale_fea=0.02, lr_quant_scale_xw=0.01, lr_quant_scale_weight=0.02, lr_quant_bit_fea=0.008, lr_quant_bit_weight=0.0001, lr_step_size=50, lr_decay_factor=0.5, lr_schedule_patience=10, is_naive=False, resume=True, store_ckpt=True, uniform=True, use_norm_quant=True, a_storage=1, result_folder='result', check_folder='checkpoint', pathdataset='/')


In [25]:
###############################################################
model = args.model
dataset_name = args.dataset_name
num_layers = args.num_layers
hidden_units=args.hidden_units
bit=args.bit
max_epoch = args.max_epoch
resume = args.resume


# Path direction
pathresult = args.result_folder+'/'+args.model+'_'+dataset_name
pathcheck = args.check_folder+'/'+args.model+'_'+dataset_name
if not os.path.exists(pathresult):
    os.makedirs(pathresult)
if not os.path.exists(pathcheck):
    os.makedirs(pathcheck)
###############################################################


## Loading Dataset and Normalization

In [26]:
def get_dataset(dataset_dir, dataset_name):
    dataset = TUDataset(dataset_dir, dataset_name)
    
    if dataset.data.x is None:
        max_degree = 0
        degs = []
        for data in dataset:
            degs += [degree(data.edge_index[0], dtype=torch.long)]
            max_degree = max(max_degree, degs[-1].max().item())

        if max_degree < 1000:
            dataset.transform = T.OneHotDegree(max_degree)
        else:
            deg = torch.cat(degs, dim=0).to(torch.float)
            mean, std = deg.mean().item(), deg.std().item()
            dataset.transform = NormalizedDegree(mean, std)
            
    return dataset                

In [27]:
dataset=get_dataset(args.pathdataset, args.dataset_name)



## Training Process

In [44]:
def analysis_bit_proteins(dataset, state_dict, all_positive=True):
    # Step 1: Collect layer-wise bit parameters
    layer_bits = {}
  
    for key, param in state_dict.items():
        if'quant' in key and 'bit' in key and 'fea' in key:
            layer_name = key.split('.quant_bit_fea')[0]
            layer_bits[layer_name] = param.abs().round() - 1

    # Step 2: Per-graph analysis
    for i, data in enumerate(dataset):
        #print(f"\n===== Analyzing Graph {i+1}/{len(dataset)} =====")
        edge_index = data.edge_index
        row, col = edge_index
        deg = degree(col, data.x.size(0)).cpu()
        
        # Step 3: Per-layer analysis within current graph
        for layer_name, bits in layer_bits.items():
            # Skip if bits tensor doesn't match current graph size
            if bits.size(0) != deg.size(0):
                continue
                
            print(f"\nLayer {layer_name}:")
            print(f"Avg bits: {bits.mean().item():.2f}")
            
            # Bit-degree correlation
            for bit_val in range(0, 9):
                mask = (bits == bit_val)
                if mask.sum() > 0:
                    avg_deg = deg[mask].mean().item()
                    print(f"  {bit_val}-bit nodes: {mask.sum().item()} nodes, Avg Degree={avg_deg:.1f}")
    
    # Step 4: Weight quantization analysis
    weight_bits = []
    for key, param in state_dict.items():
        if'quant' in key and 'bit' in key and 'fea' in key:
            bits = param.abs().round() - 1
            weight_bits.append(bits.mean().item())
    
    print("\n===== Weight Quantization Summary =====")
    if weight_bits:
        print(f"Avg weight bits: {sum(weight_bits)/len(weight_bits):.2f}")
    else:
        print("No weight quantization parameters found")
    
    print("Analysis complete")
    return {sum(weight_bits)/len(weight_bits)}

In [56]:
def run(bit=32, max_epoch=5):
        args.batch_size=16
        args.max_epoch=5
        args.batch_size=32
        max_acc =0.5
        
        if bit== 32:
            qypte='FP32'
        elif bit== 4:
            qypte= 'INT4'    
        elif bit== 8:
            qypte = 'INT8'
        
     
      
        val_losses, accu, durations = [], [], []
        quant_model_accuracy=[]
        quant_model_loss=[]
        t_quant_model=[]
        Num_parm_quant_model=[]
        quant_model_size=[]
        quant_energy_consumption=[]
        quant_cpu_usage=[]
        quant_memory_usage=[]
       
 
        #Eva= OrderedDict()
        #Eva=dict()
        # Initialize a dictionary to store all results per iteration
        Eva_iter = {
            "val losses per iter": [],
            "durations per iter": [],
            "quant model accuracy per iter": [],
            "time inference of quant model per iter": [],
            "number parmameters of quant model per iter": [],  # Store the best accuracy for each fold
            "size of quant model per iter": [],
            "energy consumption of quant model per iter": [],
            "cpu usage of quant model per iter": [],
            "total memory usage of quant model per iter": [],
            "final_metrics": {}  # Store final metrics (mean, std, etc.)
        }
    
        
     


        for fold, (train_idx, test_idx, val_idx) in enumerate(zip(*k_fold(dataset, args.folds))):
            print_max_acc=0
            train_dataset = dataset[train_idx.tolist()]
            test_dataset = dataset[test_idx.tolist()]
            val_dataset = dataset[val_idx.tolist()]
            train_loader = DataLoader(train_dataset, args.batch_size, num_workers=0,shuffle=False, drop_last=True)
            val_loader = DataLoader(val_dataset, args.batch_size,num_workers=0,shuffle=False,drop_last=True)
            test_loader = DataLoader(test_dataset, args.batch_size,num_workers=0,shuffle=False, drop_last=True)
            k=0


            model=qGIN(train_dataset, args.num_layers,hidden_units=args.hidden_units,bit=args.bit, is_q=True,
                    num_deg=args.num_deg,
                    uniform=args.uniform).to(device)
            weight_paras,quant_paras_bit_weight, quant_paras_bit_fea, quant_paras_scale_weight, quant_paras_scale_fea, quant_paras_scale_xw, quant_paras_bit_xw, other_paras = paras_group(model)
            # quant_paras_bit.requires_grad = False
            optimizer = torch.optim.Adam([{'params':weight_paras},
                                        {'params':quant_paras_scale_weight,'lr':args.lr_quant_scale_weight,'weight_decay':0},
                                        {'params':quant_paras_scale_fea,'lr':args.lr_quant_scale_fea,'weight_decay':0},
                                        {'params':quant_paras_scale_xw,'lr':args.lr_quant_scale_xw,'weight_decay':0},
                                        # {'params':quant_paras_bit_weight,'lr':args.lr_quant_bit_weight,'weight_decay':0},
                                        {'params':quant_paras_bit_fea,'lr':args.lr_quant_bit_fea,'weight_decay':0},
                                        {'params':other_paras}],
                                        lr=args.lr, weight_decay=args.weight_decay)
            scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_decay_factor)
            
            t_start = time.perf_counter()

            Eva_fold= OrderedDict() #It is a dictionary to arrange output of this fold
            
            # Remove saved model from previous fold
            files = glob.glob('*.pth.tar')
            for f in files:
                 #if f!='{}.pth'.format(best_model):
                        os.remove(f)
            for epoch in range(max_epoch):
                #t = tqdm(epoch)
                train_loss=0
                train_loss = train(model,optimizer,train_loader,args.a_loss, args.a_storage)
                start = time.process_time()
                val_loss = eval_loss(model,val_loader)
                val_losses.append(val_loss)
                end = time.process_time()
                acc = eval_acc(model,test_loader)
                

                if epoch % 50 == 0:
                    print(f"Eval Epoch: {epoch} |Val_loss:{val_loss:.03f}| Train_Loss: {train_loss:.3f} | Acc: {acc:.3f}|Fold: {fold}")
                accu.append(acc)
                if(acc>max_acc):
                    max_acc = acc
                    path=pathcheck+'/'+args.model+'_'+dataset_name+'_'+str(bit)+'bit'+'quantized.pth.tar'
                    torch.save({'state_dict': model.state_dict(), 'best_accu': acc,}, path)
                if(acc>print_max_acc):
                    print_max_acc = acc

            t_end = time.perf_counter()
            durations.append(t_end - t_start)      
                    
            # Start monitoring CPU and memory usage, model size, number of parametes, time inference and  power consumption
            quant_model_path=pathcheck+'/'+args.model+'_'+dataset_name+'_'+str(bit)+'bit'+'quantized.pth.tar'
            #state = torch.load(quant_model_path)
            #dict=state['state_dict']
            #recover_model = lambda: model.load_state_dict(state['state_dict'])

            gc.collect()
            time.sleep(5)  # Add a 5-second delay to stabilize the initial state
            tracemalloc.start()  # Start tracking memory allocations
            snapshot_before = tracemalloc.take_snapshot()#take a snapshot of the current memory state before starting the measurement.

            t0 = time.perf_counter()
            initial_cpu_usage = get_cpu_usage()
            power_usage = estimate_power_usage(initial_cpu_usage)


            fold_quant_model_accuracy= eval_acc(model, test_loader)

            fold_quant_cpu_usage = get_cpu_usage()
            t1 = time.perf_counter()
            fold_t_quant_model=t1-t0

            snapshot_after = tracemalloc.take_snapshot()
            tracemalloc.stop()
            top_stats = snapshot_after.compare_to(snapshot_before, 'lineno')

            folde_quant_total_memory_diff = sum([stat.size_diff for stat in top_stats])
            fold_quant_energy_consumption = power_usage * fold_t_quant_model
            #fold_quant_model_size = os.path.getsize(main_model_path)
            fold_quant_model_size =calculate_model_size(model, qypte )
            fold_num_parm_quant_model=get_num_parameters(model, count_nonzero_only=True)
            

            gc.collect()
            time.sleep(5) 
            #Update Eva dictionary
            Eva_fold.update({'quant model accuracy per fold': fold_quant_model_accuracy,
                        'time inference of quant model per fold':fold_t_quant_model,
                        'number parmameters of quant model per fold': fold_num_parm_quant_model,
                        'size of quant model per fold': fold_quant_model_size, 
                        'energy consumption of quant model per fold':fold_quant_energy_consumption,
                        'total memory usage of quant model per fold':folde_quant_total_memory_diff,
                        'cpu usage of quant model per fold':fold_quant_cpu_usage
                       })

            gc.collect()
            time.sleep(5) 


            quant_model_accuracy.append(Eva_fold['quant model accuracy per fold'])
            t_quant_model.append(Eva_fold['time inference of quant model per fold'])
            Num_parm_quant_model.append(int(Eva_fold['number parmameters of quant model per fold']))
            quant_model_size.append(int(Eva_fold['size of quant model per fold']))
            quant_energy_consumption.append(Eva_fold['energy consumption of quant model per fold'])
            quant_cpu_usage.append(Eva_fold['cpu usage of quant model per fold'])
            quant_memory_usage.append(Eva_fold['total memory usage of quant model per fold'])
           # analysis_bit_proteins(dataset, model.state_dict(), all_positive=True)





        Eva_iter["quant model accuracy per iter"]= stat.mean(quant_model_accuracy)
        Eva_iter["time inference of quant model per iter"]= stat.mean(t_quant_model)
        Eva_iter["number parmameters of quant model per iter"]=  stat.mean(Num_parm_quant_model)
        Eva_iter["size of quant model per iter"]= stat.mean(quant_model_size)
        Eva_iter["energy consumption of quant model per iter"]= stat.mean(quant_energy_consumption)
        Eva_iter["cpu usage of quant model per iter"]= stat.mean(quant_cpu_usage)
        Eva_iter["total memory usage of quant model per iter"]= stat.mean(quant_memory_usage)


        loss, acc, duration = tensor(val_losses), tensor(accu), tensor(durations)
        loss, acc = loss.view(args.folds, max_epoch), acc.view(args.folds, max_epoch)
        loss, argmin = loss.min(dim=1)
        acc = acc[torch.arange(args.folds, dtype=torch.long), argmin]

        Eva_iter["val losses per iter"]= loss.mean().item()
        Eva_iter["durations per iter"]= duration.mean().item()


        state = torch.load(quant_model_path)
        dict=state['state_dict']
     

        return Eva_iter , model                                       


### Manual Measurement

In [57]:
# The following are all list of criteria for measurements. 
# We collect all desired datas of each list across iterations. 
# Then, we compute average and std of each list.



#quant model
Quant_val_loss=[]
Quant_duration=[]
Quant_model_accuracy=[]
T_quant_model=[]
Num_parm_quant_model=[]
Quant_model_size=[]
Quant_Energy_Consumption=[]
Quant_Cpu_Usage=[]
Quant_Memory_Usage=[]


# Here is the dictionary to record the list of all measurements
Eva_measure={'quant validation loss':Quant_val_loss,
             'quant duration':Quant_duration,
            'quant model accuracy': Quant_model_accuracy,
            'time inference of quant model':T_quant_model,
            'number parmameters of quant model':Num_parm_quant_model,
            'quant model size':Quant_model_size,
            'energy consumption of quant model':Quant_Energy_Consumption,
            'cpu usage of quant model':Quant_Cpu_Usage,
            'memory usage of quant model':Quant_Memory_Usage}


In [58]:
args.max_epoch=50
max_epoch = args.max_epoch
iterations=1
args.bit=8
bit=args.bit
folds=10

In [59]:
#### load the quantized  model

for i in range(iterations):
    print('********************************************')
    print(f'The iteration is :{i+1} ')
   
 

    
    Eva_iter,model=run(bit, max_epoch)

 
    Quant_val_loss.append(Eva_iter["val losses per iter"])
    Quant_duration.append(Eva_iter["durations per iter"])
    Quant_model_accuracy.append(Eva_iter["quant model accuracy per iter"])
    T_quant_model.append(Eva_iter["time inference of quant model per iter"])
    Num_parm_quant_model.append(Eva_iter["number parmameters of quant model per iter"])
    Quant_model_size.append(Eva_iter["size of quant model per iter"])
    Quant_Energy_Consumption.append(Eva_iter["energy consumption of quant model per iter"])
    Quant_Cpu_Usage.append( Eva_iter["cpu usage of quant model per iter"])
    Quant_Memory_Usage.append(Eva_iter["total memory usage of quant model per iter"])
 

********************************************
The iteration is :1 




Eval Epoch: 0 |Val_loss:0.903| Train_Loss: 1.642 | Acc: 0.596|Fold: 0
Eval Epoch: 0 |Val_loss:1.950| Train_Loss: 1.832 | Acc: 0.596|Fold: 1
Eval Epoch: 0 |Val_loss:0.795| Train_Loss: 2.060 | Acc: 0.596|Fold: 2
Eval Epoch: 0 |Val_loss:0.829| Train_Loss: 1.781 | Acc: 0.595|Fold: 3
Eval Epoch: 0 |Val_loss:1.949| Train_Loss: 1.866 | Acc: 0.595|Fold: 4


In [60]:
# 50 Epoch-bit =8
quant_model_path=pathcheck+'/'+args.model+'_'+dataset_name+'_'+str(bit)+'bit'+'quantized.pth.tar'
state = torch.load(quant_model_path)
dict=state['state_dict']
analysis_bit_proteins(dataset, model.state_dict(), all_positive=True)


===== Weight Quantization Summary =====
Avg weight bits: 6.08
Analysis complete


{6.084250013033549}

In [70]:
## Save model by Danny bit=8
bit=4
quant_model_path= args.model+'_'+dataset_name+'_'+str(bit)+'bit'+'quantized.pth.tar'
print(quant_model_path)
state = torch.load(quant_model_path,map_location=torch.device('cpu'))
dict=state['state_dict']
analysis_bit_proteins(dataset, model.state_dict(), all_positive=True)

GIN_PROTEINS_4bitquantized.pth.tar

===== Weight Quantization Summary =====
Avg weight bits: 6.08
Analysis complete


{6.084250013033549}

In [69]:
## Save model by Danny bit=4


args.bit=4
bit=args.bit
quant_model_path= args.model+'_'+dataset_name+'_'+str(bit)+'bit'+'quantized.pth.tar'
print(quant_model_path)
state = torch.load(quant_model_path,map_location=torch.device('cpu'))
dict=state['state_dict']
analysis_bit_proteins(dataset, model.state_dict(), all_positive=True)

GIN_PROTEINS_4bitquantized.pth.tar

===== Weight Quantization Summary =====
Avg weight bits: 6.08
Analysis complete


{6.084250013033549}

In [71]:
# 20 Epoch-bit =8
bit =8
quant_model_path=pathcheck+'/'+args.model+'_'+dataset_name+'_'+str(bit)+'bit'+'quantized.pth.tar'
print(quant_model_path)
state = torch.load(quant_model_path)
dict=state['state_dict']
analysis_bit_proteins(dataset, model.state_dict(), all_positive=True)

checkpoint/GIN_PROTEINS/GIN_PROTEINS_8bitquantized.pth.tar

===== Weight Quantization Summary =====
Avg weight bits: 6.08
Analysis complete


{6.084250013033549}

In [74]:
# epoch=2, bit=8
bit=8
model=qGIN(dataset, args.num_layers,hidden_units=args.hidden_units,bit=args.bit, is_q=True,
                    num_deg=args.num_deg,
                    uniform=args.uniform).to(device)
quant_model_path=pathcheck+'/'+args.model+'_'+dataset_name+'_'+str(bit)+'bit'+'quantized.pth.tar'
state = torch.load(quant_model_path)
dict=state['state_dict']
analysis_bit_proteins(dataset, model.state_dict(), all_positive=True)


===== Weight Quantization Summary =====
Avg weight bits: 3.58
Analysis complete


{3.5833333333333335}

In [49]:
# epoch=2, bit=4
quant_model_path=pathcheck+'/'+args.model+'_'+dataset_name+'_'+str(bit)+'bit'+'quantized.pth.tar'
state = torch.load(quant_model_path)
dict=state['state_dict']
analysis_bit_proteins(dataset, model.state_dict(), all_positive=True)


===== Weight Quantization Summary =====
Avg weight bits: 3.58
Analysis complete


{3.5833333333333335}

In [None]:
# This is a dictionary to save all measurements. Aftre measuring, we can compute mean and std of each item.
from collections import OrderedDict 
Eva_final = OrderedDict()



quant_model_val_loss_mean =stat.mean(Quant_val_loss)
quant_model_val_loss_std = stat.stdev(Quant_val_loss)
#desc = "{:.3f} ± {:.3f}".format(acc_mean,acc_std)
Eva_final.update({'Ave of quant loss validation':float(format(quant_model_val_loss_mean, '.3f'))})
Eva_final.update({'Std of quant loss validation':float(format(quant_model_val_loss_std, '.3f'))})    

quant_model_duration_mean =stat.mean(Quant_duration)
quant_model_duration_std = stat.stdev(Quant_duration)
#desc = "{:.3f} ± {:.3f}".format(acc_mean,acc_std)
Eva_final.update({'Ave of quant model duration':float(format(quant_model_duration_mean , '.3f'))})
Eva_final.update({'Std of quant model duration':float(format(quant_model_duration_std, '.3f'))})                                         
                                     

quant_model_accuracy_mean =stat.mean(Quant_model_accuracy)
quant_model_accuracy_std = stat.stdev(Quant_model_accuracy)
#desc = "{:.3f} ± {:.3f}".format(acc_mean,acc_std)
Eva_final.update({'Ave of quant model accuracy':float(format(quant_model_accuracy_mean, '.3f'))})
Eva_final.update({'Std of quant model accuracy':float(format(quant_model_accuracy_std, '.3f'))})
                 

t_quant_model_mean = stat.mean(T_quant_model)
t_quant_model_std =stat.stdev(T_quant_model)
#desc = "{:.3f} ± {:.3f}".format(acc_mean,acc_std)
Eva_final.update({'Ave of time inference of quant model':float(format(t_quant_model_mean, '.3f'))})
Eva_final.update({'Std of time inference of quant model':float(format(t_quant_model_std, '.3f'))})

num_parm_quant_model_mean = stat.mean(Num_parm_quant_model)
num_parm_quant_model_std = stat.stdev(Num_parm_quant_model)
#desc = "{:.3f} ± {:.3f}".format(acc_mean,acc_std)
Eva_final.update({'Ave of number parmameters of quant model':num_parm_quant_model_mean})
Eva_final.update({'Std of number parmameters of quant model':num_parm_quant_model_std})

quant_model_size_mean =stat.mean( Quant_model_size)
quant_model_size_std = stat.stdev(Quant_model_size)
#desc = "{:.3f} ± {:.3f}".format(acc_mean,acc_std)
Eva_final.update({'Ave of quant model size':quant_model_size_mean})
Eva_final.update({'Std of quant_model_size':quant_model_size_std })

quant_energy_consumption_mean = stat.mean(Quant_Energy_Consumption)
quant_energy_consumption_std = stat.stdev(Quant_Energy_Consumption)
#desc = "{:.3f} ± {:.3f}".format(acc_mean,acc_std)
Eva_final.update({'Ave of energy consumption of quant model':quant_energy_consumption_mean })
Eva_final.update({'Std of energy consumption of quant model':quant_energy_consumption_std})


quant_cpu_usage_mean = stat.mean(Quant_Cpu_Usage)
quant_cpu_usage_std = stat.stdev(Quant_Cpu_Usage)
#desc = "{:.3f} ± {:.3f}".format(acc_mean,acc_std)
Eva_final.update({'Ave of cpu usage of quant model':quant_cpu_usage_mean})
Eva_final.update({'Std of cpu usage of quant model':quant_cpu_usage_std})

quant_memory_usage_mean = stat.mean(Quant_Memory_Usage)
quant_memory_usage_std = stat.stdev(Quant_Memory_Usage)
#desc = "{:.3f} ± {:.3f}".format(acc_mean,acc_std)
Eva_final.update({'Ave of memory usage of quant model':quant_memory_usage_mean})
Eva_final.update({'Std of memory usage of quant model':quant_memory_usage_std})

#################################


# Determing Quantization Method 

print(f"All measurement about A-A-Q Quantization process of type:{ bit} ")   
Eva_final

In [41]:
# Record the accuracy
Quantization_Method='AAQ'

file_name = pathresult+'/'+Quantization_Method+'Method'+'_On'+args.model+'_'+dataset_name+'_'+str(bit)+'bit'+'.txt'

with open(file_name, 'w') as f:
    for key, value in vars(args).items():
        f.write('%s:%s\n'%(key, value))

    for key, value in Eva_final.items():
        f.write('%s:%s\n'%(key, value))

    for key, value in Eva_measure.items():
        f.write('%s:%s\n' % (key, ','.join(map(str, value))))    

In [26]:
import inspect
from collections import OrderedDict


import torch
from torch.nn import Parameter, Module, ModuleDict
import torch.nn.functional as F
from torch_geometric.utils import (
    softmax,
    add_self_loops,
    remove_self_loops,
    add_remaining_self_loops,
    degree,
)
import torch_scatter
from torch_scatter import scatter_add
from torch_geometric.nn.inits import glorot, zeros
from torch_sparse import SparseTensor, matmul, fill_diag, sum as sparsesum, mul


def scatter_(name, src, index, dim=0, dim_size=None):
    """Taken from an earlier version of PyG"""
    assert name in ["add", "mean", "min", "max"]

    op = getattr(torch_scatter, "scatter_{}".format(name))
    out = op(src, index, dim, None, dim_size)
    out = out[0] if isinstance(out, tuple) else out

    if name == "max":
        out[out < -10000] = 0
    elif name == "min":
        out[out > 10000] = 0

    return out

def analysis_bit(data, state_dict,all_positive=True,name='plat'):
    mean_all = []
    if(name=='ogbn-arxiv'):
        adj_t = data.adj_t
        adj_t = adj_t.fill_value(1.,)
        deg = sparsesum(adj_t, dim=1)
    else:
        edge_index = data.edge_index
        row,col = edge_index
        deg = degree(col,data.x.size(0))
    for key in state_dict.keys():
        if ('quant' in key and 'bit' in key and 'fea' in key):
            if(all_positive):
                bit=state_dict[key].abs().round()-1
            else:
                bit=state_dict[key].abs().round()-1
            print(key+'\n')
            mean_all.append(bit.mean())
            print('The average bits of current layer:',bit.mean())
            print("0bit:{}".format((bit==0).sum()))
            print("1bit:{}".format((bit==1).sum()))
            print("2bit:{}".format((bit==2).sum()))
            print("3bit:{}".format((bit==3).sum()))
            print("4bit:{}".format((bit==4).sum()))
            print("5bit:{}".format((bit==5).sum()))
            print("6bit:{}".format((bit==6).sum()))
            print("7bit:{}".format((bit==7).sum()))
            print("8bit:{}".format((bit==8).sum()))
            print("9bit:{}".format((bit==9).sum()))
            print('\n')
            print('The average degree of the nodes using corresponding bitwidth:')
            index_1_bit = torch.where(bit==1)[0]
            index_2_bit = torch.where(bit==2)[0]
            index_3_bit = torch.where(bit==3)[0]
            index_4_bit = torch.where(bit==4)[0]
            index_5_bit = torch.where(bit==5)[0]
            index_6_bit = torch.where(bit==6)[0]
            index_7_bit = torch.where(bit==7)[0]
            index_8_bit = torch.where(bit==8)[0]
            print('1bit_deg_mean:',deg[index_1_bit].mean())
            print('2bit_deg_mean:',deg[index_2_bit].mean())
            print('3bit_deg_mean:',deg[index_3_bit].mean())
            print('4bit_deg_mean:',deg[index_4_bit].mean())
            print('5bit_deg_mean:',deg[index_5_bit].mean())
            print('6bit_deg_mean:',deg[index_6_bit].mean())
            print('7bit_deg_mean:',deg[index_7_bit].mean())
            print('8bit_deg_mean:',deg[index_8_bit].mean())
            print('\n')
    print('The average bits: ',sum(mean_all)/len(mean_all))
    print('Finish')

msg_special_args = set(
    [
        "edge_index",
        "edge_index_i",
        "edge_index_j",
        "size",
        "size_i",
        "size_j",
    ]
)

aggr_special_args = set(
    [
        "index",
        "dim_size",
    ]
)

update_special_args = set([])
