In [20]:
# !pip install prettytable
import json
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
from prettytable import PrettyTable

import random

import torch
import torch.nn as nn
import torch.nn.functional as F

from src.model.model_utils.network_util import (MLP, Aggre_Index, Gen_Index,
                                                build_mlp)
from src.model.transformer.attention import MultiHeadAttention


class GraphEdgeAttenNetwork(torch.nn.Module):
    """One edge-conditioned attention message-passing step over a graph.

    For every edge, the features of its two endpoint nodes are gathered
    (``index_get``) and combined with the edge feature by
    ``MultiHeadedEdgeAttention``. The resulting per-edge messages are reduced
    onto the nodes (``index_aggr``). Each node is then updated by an MLP
    (``prop``) applied to ``[own feature, aggregated message]``.
    """

    def __init__(self, num_heads, dim_node, dim_edge, dim_atten, aggr='max', use_bn=False,
                 flow='target_to_source', attention='fat', use_edge: bool = True, **kwargs):
        super().__init__()
        self.name = 'edgeatten'
        self.dim_node = dim_node
        self.dim_edge = dim_edge
        self.index_get = Gen_Index(flow=flow)

        # 'fat' honours the requested reduction; 'distance' always sums messages.
        reduction_for = {'fat': aggr, 'distance': 'add'}
        if attention not in reduction_for:
            raise NotImplementedError()
        self.index_aggr = Aggre_Index(aggr=reduction_for[attention], flow=flow)

        self.edgeatten = MultiHeadedEdgeAttention(
            dim_node=dim_node, dim_edge=dim_edge, dim_atten=dim_atten,
            num_heads=num_heads, use_bn=use_bn, attention=attention, use_edge=use_edge, **kwargs)
        # Node update MLP: [node feature, aggregated message] -> node feature.
        self.prop = build_mlp([dim_node + dim_atten, dim_node + dim_atten, dim_node],
                              do_bn=use_bn, on_last=False)

    def forward(self, x, edge_feature, edge_index, weight=None, istrain=False):
        """Return ``(updated node features, updated edge features)``.

        ``x`` and ``edge_feature`` must both be 2-D. The attention probabilities
        produced by the edge attention are computed but not returned.
        """
        assert x.ndim == 2
        assert edge_feature.ndim == 2
        node_i, node_j = self.index_get(x, edge_index)
        messages, new_edge_feature, _ = self.edgeatten(
            node_i, edge_feature, node_j, weight, istrain=istrain)
        gathered = self.index_aggr(messages, edge_index, dim_size=x.shape[0])
        updated_nodes = self.prop(torch.cat([x, gathered], dim=1))
        return updated_nodes, new_edge_feature
class MultiHeadedEdgeAttention(torch.nn.Module):
    """Multi-head attention over ``(node_i, edge, node_j)`` triplets.

    ``forward`` returns three things:

    * a per-edge message: the projected ``value`` gated element-wise by
      attention probabilities. The probabilities come from the projected
      ``query`` (and, when ``use_edge``, the projected edge) and are
      softmax-normalised over the per-head feature axis.
    * an updated edge feature computed from the concatenated triplet.
    * the attention probabilities themselves.
    """

    def __init__(self, num_heads: int, dim_node: int, dim_edge: int, dim_atten: int, use_bn=False,
                 attention='fat', use_edge: bool = True, **kwargs):
        super().__init__()
        # Every feature width must split evenly across the heads.
        for width in (dim_node, dim_edge, dim_atten):
            assert width % num_heads == 0
        self.name = 'MultiHeadedEdgeAttention'
        self.dim_node = dim_node
        self.dim_edge = dim_edge
        d_n = self.d_n = dim_node // num_heads
        d_e = self.d_e = dim_edge // num_heads
        d_o = self.d_o = dim_atten // num_heads
        self.num_heads = num_heads
        self.use_edge = use_edge
        # Edge update MLP over the concatenated (node_i, edge, node_j) triplet.
        self.nn_edge = build_mlp([dim_node * 2 + dim_edge, dim_node + dim_edge, dim_edge],
                                 do_bn=use_bn, on_last=False)
        # Probability of masking object features (only used by the disabled masking experiment).
        self.mask_obj = 0.5

        # Optional dropout for the attention MLP; None when the key is absent.
        atten_dropout = kwargs.get('DROP_OUT_ATTEN')

        self.attention = attention
        assert self.attention in ['fat']

        if attention == 'fat':
            if use_edge:
                layer_sizes = [d_n + d_e, d_n + d_e, d_o]
            else:
                layer_sizes = [d_n, d_n * 2, d_o]
            self.nn = MLP(layer_sizes, do_bn=use_bn, drop_out=atten_dropout)

            self.proj_edge = build_mlp([dim_edge, dim_edge])
            self.proj_query = build_mlp([dim_node, dim_node])
            self.proj_value = build_mlp([dim_node, dim_atten])
        elif attention == 'distance':
            self.proj_value = build_mlp([dim_node, dim_atten])

    def forward(self, query, edge, value, weight=None, istrain=False):
        """Return ``(message, updated_edge_feature, attention_probabilities)``.

        ``weight`` and ``istrain`` are accepted for interface compatibility; they
        are not read here (random masking via ``self.mask_obj`` is disabled).
        """
        n_rows = query.size(0)

        triplet = torch.cat([query, edge, value], dim=1)
        new_edge_feature = self.nn_edge(triplet)

        if self.attention == 'distance':
            raise NotImplementedError()
        if self.attention != 'fat':
            raise NotImplementedError('')

        value = self.proj_value(value)
        # Split node/edge projections into (rows, per-head features, heads).
        query = self.proj_query(query).view(n_rows, self.d_n, self.num_heads)
        edge = self.proj_edge(edge).view(n_rows, self.d_e, self.num_heads)
        scorer_input = torch.cat([query, edge], dim=1) if self.use_edge else query
        prob = self.nn(scorer_input).softmax(1)
        # Element-wise gating of the flattened probabilities with the projected value.
        message = torch.mul(prob.reshape_as(value), value)
        return message, new_edge_feature, prob

class MMG(torch.nn.Module):
    """Multi-modal graph encoder that jointly refines 3D and 2D object/edge features.

    Each of the ``depth`` layers performs, in order:
      1. self-attention over the 3D object features,
      2. cross-attention from the 2D object features (queries) to the updated 3D ones,
      3. a ``GraphEdgeAttenNetwork`` update on the 3D graph and on the 2D graph,
      4. cross-attention from the 2D edge features (queries) to the 3D edge features.

    All four outputs then go through ReLU + dropout. This happens on every layer
    except the last, or on the single layer when ``depth == 1``.
    """

    def __init__(self, dim_node, dim_edge, dim_atten, num_heads=1, aggr='max',
                 use_bn=False, flow='target_to_source', attention='fat',
                 hidden_size=512, depth=1, use_edge: bool = True, **kwargs,
                 ):
        # NOTE(review): hidden_size is accepted but not used by this class.
        super().__init__()

        self.num_heads = num_heads
        self.depth = depth

        # Module construction order is kept as-is: it fixes parameter names,
        # state_dict order and random initialisation.
        self.self_attn = nn.ModuleList(
            MultiHeadAttention(d_model=dim_node, d_k=dim_node // num_heads, d_v=dim_node // num_heads, h=num_heads) for i in range(depth))

        self.cross_attn = nn.ModuleList(
            MultiHeadAttention(d_model=dim_node, d_k=dim_node // num_heads, d_v=dim_node // num_heads, h=num_heads) for i in range(depth))

        self.cross_attn_rel = nn.ModuleList(
            MultiHeadAttention(d_model=dim_edge, d_k=dim_edge // num_heads, d_v=dim_edge // num_heads, h=num_heads) for i in range(depth))

        self.gcn_2ds = torch.nn.ModuleList()
        self.gcn_3ds = torch.nn.ModuleList()

        for _ in range(self.depth):
            self.gcn_2ds.append(GraphEdgeAttenNetwork(
                num_heads, dim_node, dim_edge, dim_atten, aggr,
                use_bn=use_bn, flow=flow, attention=attention, use_edge=use_edge,
                **kwargs))

            self.gcn_3ds.append(GraphEdgeAttenNetwork(
                num_heads, dim_node, dim_edge, dim_atten, aggr,
                use_bn=use_bn, flow=flow, attention=attention, use_edge=use_edge,
                **kwargs))

        # Maps the pairwise geometry [dx, dy, dz, distance] (4 inputs) to one
        # attention bias per head (4 -> 32 -> 32 -> num_heads).
        self.self_attn_fc = nn.Sequential(
            nn.Linear(4, 32),
            nn.ReLU(),
            nn.LayerNorm(32),
            nn.Linear(32, 32),
            nn.ReLU(),
            nn.LayerNorm(32),
            nn.Linear(32, num_heads)
        )

        # A missing/None DROP_OUT_ATTEN means "no dropout", matching how
        # MultiHeadedEdgeAttention treats the same keyword (previously KeyError).
        drop_out_atten = kwargs.get('DROP_OUT_ATTEN')
        self.drop_out = torch.nn.Dropout(0.0 if drop_out_atten is None else drop_out_atten)

    def _build_object_attention_bias(self, obj_center, batch_ids, num_objects, device):
        """Build the same-scene attention mask and the geometric per-head bias.

        Returns ``(obj_mask, obj_distance_weight)``:
          * ``obj_mask``: shape (1, 1, N, N); 1 where objects i and j share a batch id, else 0.
          * ``obj_distance_weight``: shape (1, num_heads, N, N).
            For same-scene pairs it is ``self_attn_fc([c_j - c_i, ||c_j - c_i||])``;
            elsewhere it is 0.

        Objects are grouped by their actual ids, so ``batch_ids`` need not be sorted.
        Centres are detached, so no gradient flows into ``obj_center``.
        """
        obj_mask = torch.zeros(1, 1, num_objects, num_objects, device=device)
        obj_distance_weight = torch.zeros(1, self.num_heads, num_objects, num_objects, device=device)
        centers = obj_center.detach()

        for scene_id in torch.unique(batch_ids):
            idx = torch.nonzero(batch_ids == scene_id, as_tuple=True)[0]
            rows, cols = idx.unsqueeze(1), idx.unsqueeze(0)

            scene_centers = centers[idx]
            # offset[a, b] = c_b - c_a
            offset = scene_centers.unsqueeze(0) - scene_centers.unsqueeze(1)
            dist = offset.pow(2).sum(dim=-1, keepdim=True).sqrt()
            geometry = torch.cat([offset, dist], dim=-1).unsqueeze(0)  # 1, n, n, 4

            obj_mask[:, :, rows, cols] = 1
            obj_distance_weight[:, :, rows, cols] = self.self_attn_fc(geometry).permute(0, 3, 1, 2)

        return obj_mask, obj_distance_weight

    def forward(self, obj_feature_3d, obj_feature_2d, edge_feature_3d, edge_feature_2d, edge_index, batch_ids, obj_center=None, discriptor=None, istrain=False):
        """Run ``depth`` joint refinement layers.

        Args:
            obj_feature_3d, obj_feature_2d: per-object features; dim 0 is the object
                count N (presumably (N, dim_node) — TODO confirm against caller).
            edge_feature_3d, edge_feature_2d: per-edge features (presumably (E, dim_edge)).
            edge_index: graph connectivity, forwarded to ``GraphEdgeAttenNetwork``.
            batch_ids: per-object scene id; only read when ``obj_center`` is given.
            obj_center: optional (N, 3) object centres. When given, object attention
                is restricted to the same scene and biased by relative geometry
                (``way='add'``). Otherwise no mask/bias is used (``way='mul'``).
            discriptor: unused.
            istrain: forwarded to the GNN layers.

        Returns:
            ``(obj_feature_3d, obj_feature_2d, edge_feature_3d, edge_feature_2d)``.
        """
        if obj_center is not None:
            obj_mask, obj_distance_weight = self._build_object_attention_bias(
                obj_center, batch_ids, obj_feature_3d.shape[0], obj_feature_3d.device)
            attention_matrix_way = 'add'
        else:
            # Previously only `obj_distance` was set here, so the loop below
            # raised NameError on `obj_distance_weight`.
            obj_mask = None
            obj_distance_weight = None
            attention_matrix_way = 'mul'

        for i in range(self.depth):
            # The attention modules expect a leading batch dim of 1.
            obj_feature_3d = obj_feature_3d.unsqueeze(0)
            obj_feature_2d = obj_feature_2d.unsqueeze(0)

            obj_feature_3d = self.self_attn[i](obj_feature_3d, obj_feature_3d, obj_feature_3d, attention_weights=obj_distance_weight, way=attention_matrix_way, attention_mask=obj_mask, use_knn=False)
            obj_feature_2d = self.cross_attn[i](obj_feature_2d, obj_feature_3d, obj_feature_3d, attention_weights=obj_distance_weight, way=attention_matrix_way, attention_mask=obj_mask, use_knn=False)

            obj_feature_3d = obj_feature_3d.squeeze(0)
            obj_feature_2d = obj_feature_2d.squeeze(0)

            obj_feature_3d, edge_feature_3d = self.gcn_3ds[i](obj_feature_3d, edge_feature_3d, edge_index, istrain=istrain)
            obj_feature_2d, edge_feature_2d = self.gcn_2ds[i](obj_feature_2d, edge_feature_2d, edge_index, istrain=istrain)

            edge_feature_2d = edge_feature_2d.unsqueeze(0)
            edge_feature_3d = edge_feature_3d.unsqueeze(0)

            edge_feature_2d = self.cross_attn_rel[i](edge_feature_2d, edge_feature_3d, edge_feature_3d, use_knn=False)

            edge_feature_2d = edge_feature_2d.squeeze(0)
            edge_feature_3d = edge_feature_3d.squeeze(0)

            # Activation + dropout between layers (and on the single layer when depth == 1).
            if i < (self.depth - 1) or self.depth == 1:
                obj_feature_3d = self.drop_out(F.relu(obj_feature_3d))
                obj_feature_2d = self.drop_out(F.relu(obj_feature_2d))
                edge_feature_3d = self.drop_out(F.relu(edge_feature_3d))
                edge_feature_2d = self.drop_out(F.relu(edge_feature_2d))

        return obj_feature_3d, obj_feature_2d, edge_feature_3d, edge_feature_2d

def print_parameters(model, title="Model Parameters"):
    """Print a per-parameter table of trainable parameter counts and return the total.

    Works on models pruned with ``torch.nn.utils.prune``. Pruning replaces
    ``<name>`` by a ``<name>_orig`` parameter plus a ``<name>_mask`` buffer.
    For such parameters the *effective* count (number of unmasked entries) is
    reported under the original name ``<name>``.

    Previously ``*_orig`` parameters were skipped entirely. That made every
    pruned weight disappear from the "after pruning" total.

    Args:
        model: the ``nn.Module`` to inspect.
        title: heading printed above the table.

    Returns:
        int: sum of the (effective) trainable parameter counts.
    """
    table = PrettyTable(["Layer", "Total Parameters"])
    buffers = dict(model.named_buffers())
    total_params = 0

    for name, parameter in model.named_parameters():
        if not parameter.requires_grad:
            continue
        if name.endswith("_orig"):
            base_name = name[: -len("_orig")]
            mask = buffers.get(base_name + "_mask")
            # Only entries kept by the pruning mask still contribute.
            param_total = int(torch.count_nonzero(mask).item()) if mask is not None else parameter.numel()
            name = base_name
        else:
            param_total = parameter.numel()
        table.add_row([name, param_total])
        total_params += param_total

    print(title)
    print(table)
    print(f"Total Parameters: {total_params}")
    return total_params


def prune_model(model, pruning_rate=0.2, module_types=(nn.Linear,)):
    """Globally L1-prune the ``weight`` of every module of ``module_types`` in ``model``.

    The smallest-magnitude weights across *all* selected modules are pruned
    together (``prune.global_unstructured`` with ``L1Unstructured``). Biases are
    not pruned. The pruning stays as a reparametrisation: each module keeps a
    ``weight_orig`` parameter and a ``weight_mask`` buffer. Call
    ``prune.remove`` to make it permanent.

    Args:
        model: module to prune in place.
        pruning_rate: fraction in [0, 1] (or absolute count if int) of weights to prune,
            forwarded as ``amount``.
        module_types: module class(es) whose ``weight`` is pruned; defaults to ``nn.Linear``.
    """
    parameters_to_prune = [
        (module, 'weight')
        for module in model.modules()
        if isinstance(module, module_types)
    ]
    if not parameters_to_prune:
        # global_unstructured fails on an empty list (torch.cat of nothing); nothing to prune.
        return

    prune.global_unstructured(
        parameters_to_prune,
        pruning_method=prune.L1Unstructured,
        amount=pruning_rate,
    )

# MMG hyper-parameters used for the pruning experiment.
dim_node = 512
dim_edge = 512
dim_atten = 256
num_heads = 8
drop_out_atten = 0.1
model = MMG(
    dim_node,
    dim_edge,
    dim_atten,
    num_heads,
    DROP_OUT_ATTEN=drop_out_atten,
)

# Parameter count before and after pruning 50% of the Linear weights.
before_pruning_params = print_parameters(model, "Before Pruning")
prune_model(model, pruning_rate=0.5)
after_pruning_params = print_parameters(model, "After Pruning")

# Absolute and relative reduction in the reported parameter count.
reduction = before_pruning_params - after_pruning_params
reduction_pct = reduction / before_pruning_params * 100
print(f"Reduction in parameters: {reduction} ({reduction_pct:.2f}%)")

# report_content = f"""
# Before Pruning: {before_pruning_params} parameters
# After Pruning: {after_pruning_params} parameters
# Reduction in parameters: {reduction} ({(reduction / before_pruning_params) * 100:.2f}%)
# """

# # Save the results to a file
# with open("mmg_pruning.txt", "w") as file:
#     file.write(report_content)

Before Pruning
+-----------------------------------------+------------------+
|                  Layer                  | Total Parameters |
+-----------------------------------------+------------------+
|    self_attn.0.attention.fc_q.weight    |      262144      |
|     self_attn.0.attention.fc_q.bias     |       512        |
|    self_attn.0.attention.fc_k.weight    |      262144      |
|     self_attn.0.attention.fc_k.bias     |       512        |
|    self_attn.0.attention.fc_v.weight    |      262144      |
|     self_attn.0.attention.fc_v.bias     |       512        |
|    self_attn.0.attention.fc_o.weight    |      262144      |
|     self_attn.0.attention.fc_o.bias     |       512        |
|      self_attn.0.layer_norm.weight      |       512        |
|       self_attn.0.layer_norm.bias       |       512        |
|    cross_attn.0.attention.fc_q.weight   |      262144      |
|     cross_attn.0.attention.fc_q.bias    |       512        |
|    cross_attn.0.attention.fc_k.weight 

In [30]:
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
from prettytable import PrettyTable

from src.model.model_utils.network_util import build_mlp, Gen_Index, Aggre_Index, MLP
from src.model.model_utils.networks_base import BaseNetwork
from src.model.transformer.attention import MultiHeadAttention
import inspect
from collections import OrderedDict
import os
from src.utils import op_utils
from copy import deepcopy


def apply_pruning(model, pruning_rate=0.5, save_path="pruning.txt"):
    """Permanently L1-prune every ``nn.Linear`` weight in ``model`` and print a sparsity report.

    Each Linear layer is pruned independently: ``pruning_rate`` of its weights
    with the smallest L1 magnitude are zeroed. ``prune.remove`` then makes the
    zeros permanent in ``weight``. Afterwards a table of per-parameter and total
    sparsity is printed. Previously the function object itself was printed
    instead of the table.

    Args:
        model: module to prune in place.
        pruning_rate: fraction in [0, 1] (or absolute count if int) forwarded as
            ``amount`` to ``prune.l1_unstructured``.
        save_path: currently unused; writing the report to disk is disabled.
            Write ``str(returned_table)`` yourself if a file is needed.

    Returns:
        The report table (previously ``None``).
    """
    def _sparsity_pct(non_zero, total):
        # Empty tensors/models have nothing to be sparse; avoid dividing by zero.
        return 100.0 * (1 - non_zero / total) if total else 0.0

    for module in model.modules():
        if isinstance(module, nn.Linear):
            prune.l1_unstructured(module, name='weight', amount=pruning_rate)
            prune.remove(module, 'weight')

    table = PrettyTable(["Layer", "Total Parameters", "Non-zero Parameters", "Sparsity (%)"])
    total_params = total_non_zero = 0

    for name, param in model.named_parameters():
        num_params = param.numel()
        non_zero_params = torch.count_nonzero(param).item()
        table.add_row([name, num_params, non_zero_params, f"{_sparsity_pct(non_zero_params, num_params):.2f}"])
        total_params += num_params
        total_non_zero += non_zero_params

    table.add_row(["Total", total_params, total_non_zero, f"{_sparsity_pct(total_non_zero, total_params):.2f}"])

    print(table)
    return table

