In [49]:

import numpy as np
import torch.nn.functional as F
import torch.optim as optim
from deeprobust.graph.defense import GCN 
from deeprobust.graph.global_attack import DICE 
from deeprobust.graph.utils import *
from deeprobust.graph.data import Dataset
from deeprobust.graph.defense import GCNJaccard, GCNSVD, RGCN 
from scipy.sparse import csr_matrix
# from deeprobust.graph.defense.noisy_gcn import Noisy_GCN
from deeprobust.graph.defense.noisy_gcn_with_prune import Noisy_PGCN 
from torch_geometric.nn import GINConv, GATConv, GCNConv

import argparse

import warnings
warnings.filterwarnings("ignore")


In [None]:
import torch
import torch.nn as nn

class ConcatModel(nn.Module):
    """Toy module holding two learnable vectors and one linear layer.

    Despite the class name, nothing is concatenated: ``forward`` sends
    ``param1`` through ``fc`` and then adds ``param2`` to the result.

    Args:
        dim1: length of ``param1`` and input width of ``fc``.
        dim2: length of ``param2``. Because it is added to the output of
            ``fc``, it must be broadcast-compatible with ``output_dim``.
        output_dim: output width of ``fc``.
    """

    def __init__(self, dim1, dim2, output_dim):
        super().__init__()
        # Two learnable vectors, initialised from a standard normal.
        self.param1 = nn.Parameter(torch.randn(dim1))
        self.param2 = nn.Parameter(torch.randn(dim2))

        # Fully connected layer applied to param1 in forward().
        self.fc = nn.Linear(dim1, output_dim)

    def forward(self):
        """Return ``fc(param1) + param2``; the module takes no inputs."""
        projected = self.fc(self.param1)
        return projected + self.param2

# Usage example: build a model whose dimensions are all 4, run its
# input-free forward pass once, and print the resulting tensor.
model = ConcatModel(4, 4, output_dim=4)
output = model()
print(output)


tensor([-0.3042, -0.4692, -0.6858,  0.0013], grad_fn=<AddBackward0>)


In [62]:
import torch
from torchvision.models import resnet18
import torch_pruning as tp
import torch
import torch.nn as nn
from torch_pruning.pruner.importance import GroupMagnitudeImportance
from torch_geometric.nn import GCNConv
from torch_pruning.pruner.function import BasePruningFunc        # 型別提示用

# No example inputs: this list is handed to the pruner below as-is.
example_inputs = list()

def forward_fn(model, inputs):
    """Call ``model`` with ``inputs`` unpacked as positional arguments.

    Args:
        model: any callable.
        inputs: iterable of positional arguments; may be empty.

    Returns:
        Whatever ``model(*inputs)`` returns.
    """
    result = model(*inputs)
    return result

# 1. Importance criterion: group magnitude with p=2, "mean" group reduction
#    and "mean" normalisation.
importance = GroupMagnitudeImportance(p=2, group_reduction="mean", normalizer="mean")

# 2. Layers the pruner must not touch: every Linear layer with exactly 1000
#    output features (i.e. a 1000-way final classifier, if the model has one).
ignored_layers = [
    layer
    for layer in model.modules()
    if isinstance(layer, torch.nn.Linear) and layer.out_features == 1000
]

# 3. Build the pruner. BasePruner is enough when sparse training is not required.
pruner = tp.pruner.BasePruner(
    model,
    example_inputs,
    importance=importance,
    pruning_ratio=0.1,  # pruning ratio requested from the pruner
    ignored_layers=ignored_layers,
    # round_to=1 applies no channel rounding. Rounding dims/channels to 4x or 8x
    # is recommended for acceleration, see
    # https://docs.nvidia.com/deeplearning/performance/dl-performance-convolutional/index.html
    round_to=1,
    root_module_types=(tp.ops.TORCH_CONV, tp.ops.TORCH_LINEAR),
)

# base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
# tp.utils.print_tool.before_pruning(model) # or print(model) 
# pruner.step() 
# tp.utils.print_tool.after_pruning(model) # or print(model), this util will show the difference before and after pruning 
# macs, nparams = tp.utils.count_ops_and_params(model, example_inputs) 
# print(f"MACs: {base_macs/1e9} G -> {macs/1e9} G, #Params: {base_nparams/1e6} M -> {nparams/1e6} M") 

In [74]:
# Print every module yielded by model.modules(), one after another.
for submodule in model.modules():
    print(submodule)

ConcatModel(
  (fc): Linear(in_features=4, out_features=4, bias=True)
)
Linear(in_features=4, out_features=4, bias=True)


In [60]:
# Bare attribute access: this displays the bound method object itself and
# does not call it (use model.modules() to iterate over the sub-modules).
model.modules

<bound method Module.modules of ConcatModel(
  (fc): Linear(in_features=4, out_features=4, bias=True)
)>

In [None]:
# Display the pruner's dependency-graph mapping from each traced
# module / parameter / op to its graph node.
pruner.DG.module2node 

{_ElementWiseOp_0(AddBackward0): <Node: (_ElementWiseOp_0(AddBackward0))>,
 Linear(in_features=4, out_features=4, bias=True): <Node: (fc (Linear(in_features=4, out_features=4, bias=True)))>,
 Parameter containing:
 tensor([-0.6282, -0.9758, -0.3718,  0.5046], requires_grad=True): <Node: (UnwrappedParameter_1 (torch.Size([4])))>,
 _ElementWiseOp_1(AddmmBackward0): <Node: (_ElementWiseOp_1(AddmmBackward0))>,
 _Reshape_2(): <Node: (_Reshape_2())>,
 _ElementWiseOp_3(TBackward0): <Node: (_ElementWiseOp_3(TBackward0))>,
 Parameter containing:
 tensor([ 0.3962,  0.0006, -0.2464,  0.5559], requires_grad=True): <Node: (UnwrappedParameter_0 (torch.Size([4])))>}

In [73]:
# Print the module attached to each dependency-graph node, followed by a
# separator line.
for node in pruner.DG.module2node.values():
    print(node.module, "---------", sep="\n")

_ElementWiseOp_0(AddBackward0)
---------
Linear(in_features=4, out_features=4, bias=True)
---------
Parameter containing:
tensor([-0.6282, -0.9758, -0.3718,  0.5046], requires_grad=True)
---------
_ElementWiseOp_1(AddmmBackward0)
---------
_Reshape_2()
---------
_ElementWiseOp_3(TBackward0)
---------
Parameter containing:
tensor([ 0.3962,  0.0006, -0.2464,  0.5559], requires_grad=True)
---------


In [71]:
# Print the dependency list of each dependency-graph node, followed by a
# separator line.
for node in pruner.DG.module2node.values():
    print(node.dependencies, "---------", sep="\n")

[prune_out_channels on _ElementWiseOp_0(AddBackward0) => prune_out_channels on fc (Linear(in_features=4, out_features=4, bias=True)), prune_out_channels on _ElementWiseOp_0(AddBackward0) => prune_out_channels on UnwrappedParameter_1 (torch.Size([4]))]
---------
[prune_in_channels on fc (Linear(in_features=4, out_features=4, bias=True)) => prune_out_channels on _ElementWiseOp_1(AddmmBackward0), prune_out_channels on fc (Linear(in_features=4, out_features=4, bias=True)) => prune_out_channels on _ElementWiseOp_0(AddBackward0)]
---------
[prune_out_channels on UnwrappedParameter_1 (torch.Size([4])) => prune_out_channels on _ElementWiseOp_0(AddBackward0)]
---------
[prune_out_channels on _ElementWiseOp_1(AddmmBackward0) => prune_out_channels on _Reshape_2(), prune_out_channels on _ElementWiseOp_1(AddmmBackward0) => prune_out_channels on _ElementWiseOp_3(TBackward0), prune_out_channels on _ElementWiseOp_1(AddmmBackward0) => prune_in_channels on fc (Linear(in_features=4, out_features=4, bias=

In [63]:
# Print the raw (dependency, indices) entries of every group the
# dependency graph reports.
for group in pruner.DG.get_all_groups():
    print(group._group)

[(prune_out_channels on fc (Linear(in_features=4, out_features=4, bias=True)) => prune_out_channels on fc (Linear(in_features=4, out_features=4, bias=True)), [0, 1, 2, 3]), (prune_out_channels on fc (Linear(in_features=4, out_features=4, bias=True)) => prune_out_channels on _ElementWiseOp_0(AddBackward0), [0, 1, 2, 3]), (prune_out_channels on _ElementWiseOp_0(AddBackward0) => prune_out_channels on UnwrappedParameter_1 (torch.Size([4])), [0, 1, 2, 3])]
