In [2]:
from genericpath import isfile
import json
import os
if __name__ == '__main__':
    #os.sys.path.append('./pytorch_geometric/torch_geometric')
    os.sys.path.append('./src')
from src.model.model import MMGNet
from src.utils.config import Config
from utils import util
import torch
import argparse


In [3]:
def load_config():
    # load config file
    config = Config("./config/mmgnet.json")
    #print(config)
    if 'NAME' not in config:
        config_name = os.path.basename('./config/mmgnet.json')
        if len(config_name) > len('config_'):
            name = config_name[len('config_'):]
            name = os.path.splitext(name)[0]
            translation_table = dict.fromkeys(map(ord, '!@#$'), None)
            name = name.translate(translation_table)
            config['NAME'] = name            
    config.LOADBEST = ''
    config.MODE = 'train'
    config.exp = 'test'

    return config

In [4]:
config = load_config()
os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
util.set_random_seed(config.SEED)

# if config.VERBOSE:
#     print(config)

model = MMGNet(config)
model


=== 160 classes ===
| 0             armchair:0.006|| 1             backpack:0.019|
| 2                  bag:0.005|| 3                 ball:0.143|
| 4                  bar:0.071|| 5                basin:0.286|
| 6               basket:0.009|| 7         bath cabinet:0.016|
| 8              bathtub:0.036|| 9                  bed:0.009|
|10        bedside table:0.061||11                bench:0.011|
|12                bidet:0.200||13                  bin:0.037|
|14              blanket:0.007||15               blinds:0.014|
|16                board:0.087||17                 book:0.050|
|18                books:0.080||19            bookshelf:0.065|
|20               bottle:0.041||21                  box:0.001|
|22                bread:0.286||23               bucket:0.015|
|24              cabinet:0.004||25               carpet:0.061|
|26              ceiling:0.002||27                chair:0.001|
|28             cleanser:0.500||29                clock:0.059|
|30               closet:0.083||31 

<src.model.model.MMGNet at 0x7fab997a3af0>

In [5]:
model.model

Mmgnet_teacher(
  (obj_encoder): PointNetfeat(
    (relu): ReLU()
    (conv1): Conv1d(3, 64, kernel_size=(1,), stride=(1,))
    (conv2): Conv1d(64, 128, kernel_size=(1,), stride=(1,))
    (conv3): Conv1d(128, 768, kernel_size=(1,), stride=(1,))
  )
  (rel_encoder_2d): PointNetfeat(
    (relu): ReLU()
    (conv1): Conv1d(11, 64, kernel_size=(1,), stride=(1,))
    (conv2): Conv1d(64, 128, kernel_size=(1,), stride=(1,))
    (conv3): Conv1d(128, 512, kernel_size=(1,), stride=(1,))
  )
  (rel_encoder_3d): PointNetfeat(
    (relu): ReLU()
    (conv1): Conv1d(11, 64, kernel_size=(1,), stride=(1,))
    (conv2): Conv1d(64, 128, kernel_size=(1,), stride=(1,))
    (conv3): Conv1d(128, 512, kernel_size=(1,), stride=(1,))
  )
  (mmg): MMG(
    (self_attn): ModuleList(
      (0): MultiHeadAttention(
        (attention): ScaledDotProductAttention(
          (fc_q): Linear(in_features=512, out_features=512, bias=True)
          (fc_k): Linear(in_features=512, out_features=512, bias=True)
          (fc

In [13]:
model.model.mmg.self_attn

ModuleList(
  (0): MultiHeadAttention(
    (attention): ScaledDotProductAttention(
      (fc_q): Linear(in_features=512, out_features=512, bias=True)
      (fc_k): Linear(in_features=512, out_features=512, bias=True)
      (fc_v): Linear(in_features=512, out_features=512, bias=True)
      (fc_o): Linear(in_features=512, out_features=512, bias=True)
    )
    (dropout): Dropout(p=0.1, inplace=False)
    (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (1): MultiHeadAttention(
    (attention): ScaledDotProductAttention(
      (fc_q): Linear(in_features=512, out_features=512, bias=True)
      (fc_k): Linear(in_features=512, out_features=512, bias=True)
      (fc_v): Linear(in_features=512, out_features=512, bias=True)
      (fc_o): Linear(in_features=512, out_features=512, bias=True)
    )
    (dropout): Dropout(p=0.1, inplace=False)
    (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
)

In [12]:
model.model.mmg.self_attn[0]

GraphEdgeAttenNetwork(
  (index_get): Gen_Index()
  (index_aggr): Aggre_Index()
  (edgeatten): MultiHeadedEdgeAttention(
    (nn_edge): Sequential(
      (0): Linear(in_features=1536, out_features=1024, bias=True)
      (1): ReLU()
      (2): Linear(in_features=1024, out_features=512, bias=True)
    )
    (nn): mySequential(
      (0): Conv1d(128, 128, kernel_size=(1,), stride=(1,))
      (1): ReLU()
      (2): Dropout(p=0.5, inplace=False)
      (3): Conv1d(128, 32, kernel_size=(1,), stride=(1,))
    )
    (proj_edge): Sequential(
      (0): Linear(in_features=512, out_features=512, bias=True)
    )
    (proj_query): Sequential(
      (0): Linear(in_features=512, out_features=512, bias=True)
    )
    (proj_value): Sequential(
      (0): Linear(in_features=512, out_features=256, bias=True)
    )
  )
  (prop): Sequential(
    (0): Linear(in_features=768, out_features=768, bias=True)
    (1): ReLU()
    (2): Linear(in_features=768, out_features=512, bias=True)
  )
)

In [11]:
model.model.mmg.self_attn[0].attention.fc_o

Linear(in_features=512, out_features=512, bias=True)

In [20]:
import torch_pruning as tp

In [29]:
head_id = 0
for m in model.model.mmg.modules():
    print("Head #%d"%head_id)
    print("[Before Pruning] Num Heads: %d, Head Dim: %d =>"%(m.num_heads, m.head_dim))
    print()
    head_id+=1

Head #0


AttributeError: 'MMG_student' object has no attribute 'head_dim'