In [None]:
# Load the CNN models (several different architectures)
import torch 
mnist_models = []
cifar10_models = []

def load_spec_model(module, model_index):
    """Look up the model class named ``Model<model_index>`` on ``module``.

    Args:
        module: Object (normally an imported module) that exposes attributes
            named ``Model0``, ``Model1``, ...
        model_index: Suffix selecting the variant. May be a string such as
            ``'3'`` or an integer such as ``3``.

    Returns:
        The attribute ``module.Model<model_index>``; it is returned as-is and
        not instantiated here.

    Raises:
        AttributeError: If ``module`` has no attribute with that name.
    """
    # Format rather than concatenate so an int index works as well as a str.
    return getattr(module, 'Model%s' % model_index)

# Import once, outside the loop: the module is only needed to look up the
# Model<i> classes and does not change between iterations.
import model_lib.mnist_cnn_model as father_model

for i in range(6):
    # map_location='cpu' lets checkpoints that were saved from GPU tensors load
    # on a machine without CUDA; load_state_dict copies the values into the
    # model's own parameters afterwards.
    params = torch.load('/home/ubuntu/date/hdd4/shadow_model_ckpt/mnist/models%d/shadow_benign_0.model'%i,
                        map_location='cpu')
    Model = load_spec_model(father_model, str(i))
    m = Model()
    m.load_state_dict(params)
    mnist_models.append(m)
print("mnist model loaded, count: %d" % len(mnist_models))

# Import once, outside the loop: the module is only needed to look up the
# Model<i> classes and does not change between iterations.
import model_lib.cifar10_cnn_model as father_model

for i in range(6):
    # map_location='cpu' lets checkpoints that were saved from GPU tensors load
    # on a machine without CUDA; load_state_dict copies the values into the
    # model's own parameters afterwards.
    params = torch.load('/home/ubuntu/date/hdd4/shadow_model_ckpt/cifar10/models%d/shadow_benign_0.model'%i,
                        map_location='cpu')
    Model = load_spec_model(father_model, str(i))
    m = Model()
    m.load_state_dict(params)
    cifar10_models.append(m)
print("cifar10 model loaded, count: %d" % len(cifar10_models))

In [None]:
# Scratch inspection: list the attributes of `m`. When the cells run top to
# bottom, `m` is the last model instantiated by the CIFAR-10 loading loop.
dir(m)

In [None]:
# Rebind `m` to the MNIST model at index 1; the inspection cells below read it.
m = mnist_models[1]

In [None]:
# Show the entry names stored in the selected model's state dict.
m.state_dict().keys()

In [None]:
# Display the submodule registered on `m` under the name 'max_pool_1'.
m.get_submodule('max_pool_1')


In [None]:
import re
from torchinfo import summary
def _next_conv_row(summary_rows, start):
    """Return the index of the first row at or after ``start`` that mentions 'Conv'.

    Returns ``len(summary_rows)`` when no such row exists.
    """
    for row_idx in range(start, len(summary_rows)):
        if 'Conv' in summary_rows[row_idx]:
            return row_idx
    return len(summary_rows)


def model_summary(model_type, models, input_size):
    """Collect per-layer structure information for a list of CNN models.

    For each model the torchinfo summary is rendered to text, and the rows that
    start with '├─' are used to decide whether a convolution is immediately
    followed by a max-pooling row. Iteration is driven by the entries of
    ``state_dict()``; entries whose name contains 'bias' are skipped.

    Args:
        model_type: Dataset tag such as ``'mnist'`` or ``'cifar10'``. It is only
            used to select the pooling attribute names of the MNIST ``Model1``
            variant ('max_pool_1' / 'max_pool_2' instead of 'max_pool').
        models: Iterable of instantiated models.
        input_size: Input shape without the batch dimension, e.g. ``(1, 28, 28)``;
            a batch dimension of 1 is prepended before calling ``summary``.

    Returns:
        A tuple ``(summaries, row_dim, col_dim)`` of dicts keyed by the model's
        position in ``models``:

        * ``summaries[k]``: list of ``{'name', 'num', 'maxpool'}`` dicts, one per
          non-bias state-dict entry. For conv layers ``num`` is ``out_channels``
          and ``maxpool`` is either ``None`` or a dict with the pooling module's
          ``kernel_size``, ``stride``, ``padding``, ``dilation`` and
          ``ceil_mode``. For all other layers ``num`` is 1 and ``maxpool`` is
          ``None``.
        * ``row_dim[k]`` / ``col_dim[k]``: per layer, the conv kernel height /
          width, or ``out_features`` (in both) for non-conv layers.
    """
    # Raw string: '\[' inside a normal literal is an invalid escape sequence.
    layer_row_pattern = r'├─(.*?)\['
    batch_size = (1,)

    summaries = {}
    row_dim = {}
    col_dim = {}
    for cnt, m in enumerate(models):
        summary_str = str(summary(m, batch_size + input_size))
        # Text between '├─' and the first '[' of every matching summary line.
        summary_rows = re.findall(layer_row_pattern, summary_str)
        is_mnist_model1 = model_type == 'mnist' and 'Model1' in summary_str

        # Index of the summary row matched to the most recent conv layer.
        # The original never advanced its index, so every conv layer was judged
        # by the same summary row.
        # Assumes conv layers appear in the summary in state-dict order and
        # that their rows contain 'Conv' -- TODO confirm against the model
        # definitions.
        conv_row = -1
        pool1_used = False
        layers = []
        cur_row_dim = []
        cur_col_dim = []
        for layer in m.state_dict().keys():
            if 'bias' in layer:
                continue

            cur_layer_name = layer.split('.')[0]
            module = m.get_submodule(cur_layer_name)

            if 'conv' in layer:
                conv_row = _next_conv_row(summary_rows, conv_row + 1)
                pool_row = conv_row + 1
                has_pool = (pool_row < len(summary_rows)
                            and 'MaxPool' in summary_rows[pool_row])

                mp_info = None
                if has_pool:
                    mp_name = 'max_pool'
                    if is_mnist_model1:
                        # First pooled conv uses 'max_pool_1', later ones use
                        # 'max_pool_2'. The original used two independent
                        # `if`s, so the first one was overwritten with
                        # 'max_pool_2' immediately.
                        if not pool1_used:
                            mp_name = 'max_pool_1'
                            pool1_used = True
                        else:
                            mp_name = 'max_pool_2'

                    pool = m.get_submodule(mp_name)
                    mp_info = {
                        'kernel_size': pool.kernel_size,
                        'stride': pool.stride,
                        'padding': pool.padding,
                        'dilation': pool.dilation,
                        'ceil_mode': pool.ceil_mode,
                    }

                layers.append({'name': cur_layer_name, 'num': module.out_channels,
                               'maxpool': mp_info})
                cur_row_dim.append(module.kernel_size[0])
                cur_col_dim.append(module.kernel_size[1])
            else:
                # Non-conv (dense) layer: represented as a single cell.
                layers.append({'name': cur_layer_name, 'num': 1, 'maxpool': None})
                # NOTE(review): both dims record out_features, while cnn2graph
                # builds the dense node feature from weight.t() + bias -- confirm
                # whether the row dimension should be in_features instead.
                cur_row_dim.append(module.out_features)
                cur_col_dim.append(module.out_features)

        row_dim[cnt] = cur_row_dim
        col_dim[cnt] = cur_col_dim
        summaries[cnt] = layers

    return summaries, row_dim, col_dim

# Summarise the MNIST models using a (1, 28, 28) input shape.
mnist_summary, m_row_dim, m_col_dim = model_summary('mnist', mnist_models, (1, 28, 28))
def max_dim_count(mnist_summary, m_row_dim, m_col_dim):
    """Print each model's total cell count and its largest row*col product.

    Despite the parameter names, the function is dataset-agnostic.

    Args:
        mnist_summary: Per-model layer lists; each layer is a dict with a
            ``'num'`` entry.
        m_row_dim: Per-model list of row sizes, indexable by 0..n-1.
        m_col_dim: Per-model list of column sizes, aligned with ``m_row_dim``.

    Returns:
        None. One line per model is written to stdout.
    """
    for model_idx in range(len(m_row_dim)):
        positions = range(len(m_row_dim[model_idx]))
        products = [m_row_dim[model_idx][pos] * m_col_dim[model_idx][pos]
                    for pos in positions]
        # Seed with 0 so a model without layers reports a max dim of 0.
        largest = max([0] + products)
        total_cells = sum(mnist_summary[model_idx][pos]['num'] for pos in positions)
        print("%d: num of cells: %d, max cell dim: %d" % (model_idx, total_cells, largest))
# Debug aid (disabled): dump the raw MNIST dimension dicts.
# print(m_row_dim, m_col_dim)
# Report cell count and largest cell dimension for every MNIST model.
max_dim_count(mnist_summary, m_row_dim, m_col_dim)
# Repeat the summary for the CIFAR-10 models using a (3, 32, 32) input shape.
cifar10_summary, c_row_dim, c_col_dim = model_summary('cifar10', cifar10_models, (3, 32, 32))
# Debug aid (disabled): dump the raw CIFAR-10 summary and dimension dicts.
# print(cifar10_summary, c_row_dim, c_col_dim)
max_dim_count(cifar10_summary, c_row_dim, c_col_dim)

In [None]:
import json
def _largest_dim(*dim_maps):
    """Return the largest entry in any per-model dimension list, or -1 if there is none."""
    return max(
        (dim for dim_map in dim_maps for dims in dim_map.values() for dim in dims),
        default=-1,
    )


def get_model_detail():
    """Summarise the MNIST and CIFAR-10 models and save the result as JSON.

    Calls ``model_summary`` for both model lists, records the largest row and
    column dimension seen across all layers of all models, prints the resulting
    dict and writes it to ``./intermediate_data/model_detail.json``.

    Returns:
        A dict with keys ``'mnist'`` and ``'cifar10'`` (the per-model layer
        summaries) and ``'max_size'`` (``[max_row_dim, max_col_dim]``; each is
        -1 when no dimensions were collected).

    Note:
        ``json.dump`` turns the integer model indices into string keys, so a
        dict loaded back from the file is keyed by ``'0'``, ``'1'``, ... rather
        than ``0``, ``1``, ...
    """
    import os  # local import: used only to create the output directory

    mnist_summary, m_row_dim, m_col_dim = model_summary('mnist', mnist_models, (1, 28, 28))
    cifar10_summary, c_row_dim, c_col_dim = model_summary('cifar10', cifar10_models, (3, 32, 32))

    # The *_dim values are dicts of per-model lists. The original compared the
    # dicts with '>', which raises TypeError in Python 3; reduce them to scalar
    # maxima instead.
    max_row_dim = _largest_dim(m_row_dim, c_row_dim)
    max_col_dim = _largest_dim(m_col_dim, c_col_dim)

    model_detail = {
        'mnist': mnist_summary,
        'cifar10': cifar10_summary,
        'max_size': [max_row_dim, max_col_dim],
    }
    print(model_detail)

    model_detail_path = "./intermediate_data/model_detail.json"
    # Create the target directory first so a fresh checkout does not fail on open().
    os.makedirs(os.path.dirname(model_detail_path), exist_ok=True)
    with open(model_detail_path, 'w') as f:
        json.dump(model_detail, f)
    return model_detail
# Build the detail dict; this also writes ./intermediate_data/model_detail.json.
model_detail = get_model_detail()

In [None]:
import json

# Reload the saved detail file. Keys that were integer model indices when the
# file was written come back as strings ('0', '1', ...).
model_detail_path = "./intermediate_data/model_detail.json"
with open(model_detail_path, 'r') as f:
    model_detail = json.load(f)

# Count the cells (sum of the layers' 'num') of every model and track the maximum.
max_num = 0
num_list = []
for dataset_name, dataset_models in model_detail.items():
    # Skip non-dataset entries such as 'max_size'.
    if dataset_name not in ['mnist', 'cifar10']:
        continue
    for model_layers in dataset_models.values():
        # Reset per model: the original reset only once per dataset, so the
        # count silently accumulated across all models of that dataset.
        cur_num = 0
        for layer in model_layers:
            print(layer)
            cur_num += layer['num']
        # Record every model's count, so that max_num == max(num_list).
        num_list.append(cur_num)
        if cur_num > max_num:
            max_num = cur_num
print(max_num)
print(num_list)

In [None]:
# Peek at the first (weight, bias) pair of the 'conv1' submodule and at the sum
# that cnn2graph uses as a conv node feature.
conv1 = m.get_submodule('conv1')
first_pair = next(zip(conv1.weight, conv1.bias), None)
if first_pair is not None:
    weight, bias = first_pair
    print(weight[0])
    print(bias)
    w = weight[0] + bias
    print(w)

In [None]:
# Scratch inspection: list the names exposed by the torch package.
dir(torch)

In [None]:
# Print the transposed weight of the 'output' submodule with its bias added --
# the same expression cnn2graph uses for non-conv layers.
print(m.get_submodule('output').weight.t() + m.get_submodule('output').bias)

In [None]:
# cnn2graph
# mnist input size: 1,28,28
# cifar10 input size: 3,32,32
from torchinfo import summary
import os 
import dgl


def all_u_to_v(src, dst):
    """Return every ``[u, v]`` pair linking a node in ``src`` to a node in ``dst``.

    Pairs are grouped by destination: all sources for the first ``v`` come
    first, then all sources for the second ``v``, and so on.
    """
    return [[u, v] for v in dst for u in src]

def cnn2graph(model, model_info, y):
    """Build a DGL graph from a CNN's layers and return the node features.

    Each output filter of a layer whose name contains 'conv' becomes one node;
    every other layer becomes a single node. Consecutive layers are fully
    connected (every node of layer k points to every node of layer k+1). The
    graph is printed; node features are returned but not attached to it.

    Args:
        model: Model providing ``get_submodule(name)``; each referenced
            submodule must have ``weight`` and ``bias`` tensors.
        model_info: Sequence of layer dicts with at least a ``'name'`` key, in
            network order (one entry of the per-model summary).
        y: Label for the graph. It is converted to a tensor but not otherwise
            used yet.

    Returns:
        A list of tensors, one per node, in node-id order:
        ``weight[0] + bias`` for each conv filter and ``weight.t() + bias`` for
        each non-conv layer. Their shapes may differ from node to node.
    """
    # Kept from the original: converting y still rejects unusable labels, but
    # nothing attaches the label to the graph yet.
    label = torch.IntTensor([y])
    layers = []
    all_node_feats = []
    cnt = 0
    with torch.no_grad():
        for i in range(len(model_info)):
            cur_layer_info = model_info[i]
            module = model.get_submodule(cur_layer_info['name'])
            cur_layer_node = []

            if 'conv' in cur_layer_info['name']:
                # One node per (filter weight, bias) pair; only the slice at
                # index 0 of each filter's weight is used as the feature.
                for weight, bias in zip(module.weight, module.bias):
                    cur_layer_node.append(cnt)
                    cnt += 1
                    all_node_feats.append(weight[0] + bias)
            else:
                # The whole non-conv layer is a single node.
                cur_layer_node.append(cnt)
                cnt += 1
                all_node_feats.append(module.weight.t() + module.bias)
            layers.append(cur_layer_node)

    # Fully connect each layer to the next one.
    all_edges = []
    for src_nodes, dst_nodes in zip(layers, layers[1:]):
        all_edges += all_u_to_v(src_nodes, dst_nodes)

    if all_edges:
        edge_index = torch.tensor(all_edges).t()
        u, v = edge_index[0], edge_index[1]
    else:
        # Fewer than two layers: there are no edges, and indexing the empty
        # tensor (as the original did) would raise IndexError.
        u = v = torch.empty(0, dtype=torch.int64)
    # Pass the node count explicitly so nodes without edges are still present.
    g = dgl.graph((u, v), num_nodes=cnt)
    # g.ndata['x'] = torch.stack(all_node_feats)
    print(g)
    return all_node_feats
# The per-model keys are ints when model_detail comes straight from
# get_model_detail(), but strings ('0', '1', ...) once it has been reloaded
# from the JSON file, so accept either form.
mnist_detail = model_detail['mnist']
first_model_info = mnist_detail[0] if 0 in mnist_detail else mnist_detail['0']
all_node_feats = cnn2graph(mnist_models[0], first_model_info, 1)

In [None]:
# Convert every node feature to a NumPy array, pack them into one tensor and
# move it to the GPU.
# NOTE(review): cnn2graph mixes conv features (weight[0] + bias) with dense
# features (weight.t() + bias); if their shapes differ, torch.tensor() cannot
# build a rectangular tensor from them -- confirm whether padding/resizing to a
# common size is intended. The .cuda() call also requires a CUDA device.
val= torch.tensor([item.cpu().detach().numpy() for item in all_node_feats]).cuda()

In [None]:
# Display the list of node feature tensors returned by cnn2graph.
all_node_feats