In [1]:
from data_generator import ssp_data
import pickle

import os
import random
import argparse

import torch
import torch.nn as nn
import numpy as np

from sklearn.metrics import f1_score
import pdb
import net
import pruning
import copy
from scipy.sparse import coo_matrix
import warnings

In [2]:
file = open('data_pickle', 'rb')
ssp_obj = pickle.load(file)
file.close()

In [3]:
import torch
import torch.nn as nn
import pdb
import copy

def torch_normalize_adj(adj):
    adj = adj + torch.eye(adj.shape[0]).cpu()
    rowsum = adj.sum(1)
    d_inv_sqrt = torch.pow(rowsum, -0.5).flatten()
    d_inv_sqrt[torch.isinf(d_inv_sqrt)] = 0.0
    d_mat_inv_sqrt = torch.diag(d_inv_sqrt).cpu()
    return adj.mm(d_mat_inv_sqrt).t().mm(d_mat_inv_sqrt)

class net_gcn(nn.Module):

    def __init__(self, embedding_dim, adj):
        super().__init__()

        self.layer_num = len(embedding_dim) - 1
        self.net_layer = nn.ModuleList([nn.Linear(embedding_dim[ln], embedding_dim[ln+1], bias=False) for ln in range(self.layer_num)])
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(p=0.5)
        self.adj_nonzero = torch.nonzero(adj, as_tuple=False).shape[0]
        self.adj_mask1_train = nn.Parameter(self.generate_adj_mask(adj))
        self.adj_mask2_fixed = nn.Parameter(self.generate_adj_mask(adj), requires_grad=False)
        self.normalize = torch_normalize_adj
    
    def forward(self, x, adj, val_test=False):
        
        adj = torch.mul(adj, self.adj_mask1_train)
        adj = torch.mul(adj, self.adj_mask2_fixed)
        adj = self.normalize(adj)
        #adj = torch.mul(adj, self.adj_mask2_fixed)
        for ln in range(self.layer_num):
            x = torch.mm(adj, x)
            x = self.net_layer[ln](x)
            if ln == self.layer_num - 1:
                break
            x = self.relu(x)
            if val_test:
                continue
            x = self.dropout(x)
        return x

    def generate_adj_mask(self, input_adj):
        
        sparse_adj = input_adj
        zeros = torch.zeros_like(sparse_adj)
        ones = torch.ones_like(sparse_adj)
        mask = torch.where(sparse_adj != 0, ones, zeros)
        return mask

In [4]:
# pruning.setup_seed(seed)

# adj, features, labels, idx_train, idx_val, idx_test = load_data(args['dataset'])

args = {
    'embedding_dim': [74100, 512, 74100],
    'lr': 0.01,
    'weight_decay': 5e-4,
    'pruning_percent_wei': 0.2,
    'pruning_percent_adj': 0.05,
    'total_epoch': 2000,
    's1': 1e-6,
    's2': 1e-3,
    'init_soft_mask_type' : 'all_one',
    'weight_dir': False,
}

def run_fix_mask(args, seed, rewind_weight_mask):

    pruning.setup_seed(seed)

    adj = ssp_obj.train_data.edge_index
    features = ssp_obj.train_data.x.float()
    labels = ssp_obj.train_data.y.float()

    node_num = features.size()[0]
    class_num = labels.numpy().max() + 1

    adj = adj.cpu()
    features = features.cpu()
    labels = labels.cpu()
    loss_func = nn.MSELoss()

    net_gcn = net.net_gcn(embedding_dim=args['embedding_dim'], adj=adj)
    pruning.add_mask(net_gcn)
    net_gcn = net_gcn.cpu()
    net_gcn.load_state_dict(rewind_weight_mask)
    adj_spar, wei_spar = pruning.print_sparsity(net_gcn)

    for name, param in net_gcn.named_parameters():
        if 'mask' in name:
            param.requires_grad = False

    optimizer = torch.optim.Adam(net_gcn.parameters(), lr=args['lr'], weight_decay=args['weight_decay'])
    acc_test = 0.0
    best_val_acc = {'val_acc': 0, 'epoch' : 0, 'test_acc': 0}

    for epoch in range(200):

        optimizer.zero_grad()
        output = net_gcn(features, adj)
        loss = loss_func(output, labels.float())
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            output = net_gcn(features, adj, val_test=True)
            acc_val = f1_score(labels.cpu().numpy(), output.cpu().numpy().argmax(axis=1), average='micro')
            # acc_test = f1_score(labels[idx_test].cpu().numpy(), output[idx_test].cpu().numpy().argmax(axis=1), average='micro')
            if acc_val > best_val_acc['val_acc']:
                best_val_acc['val_acc'] = acc_val
                best_val_acc['test_acc'] = acc_test
                best_val_acc['epoch'] = epoch

        print("(Fix Mask) Epoch:[{}] Val:[{:.2f}] Test:[{:.2f}] | Final Val:[{:.2f}] Test:[{:.2f}] at Epoch:[{}]"
                    .format(epoch, acc_val * 100, acc_test * 100, 
                                best_val_acc['val_acc'] * 100, 
                                best_val_acc['test_acc'] * 100, 
                                best_val_acc['epoch']))
        return best_val_acc['val_acc'], best_val_acc['test_acc'], best_val_acc['epoch'], adj_spar, wei_spar
    
def run_get_mask(args, seed, imp_num, rewind_weight_mask=None):

    pruning.setup_seed(seed)
    # adj, features, labels, idx_train, idx_val, idx_test = load_data(args['dataset'])

    adj = ssp_obj.train_data.edge_index
    features = ssp_obj.train_data.x
    labels = ssp_obj.train_data.y
    
    node_num = features.size()[0]
    class_num = labels.numpy().max() + 1

    adj = adj.cpu()
    features = features.cpu()
    labels = labels.cpu()
    loss_func = nn.CrossEntropyLoss()

    net_gcn = net.net_gcn(embedding_dim=args['embedding_dim'], adj=adj)
    pruning.add_mask(net_gcn)
    net_gcn = net_gcn.cpu()

    if args['weight_dir']:
        print("load : {}".format(args['weight_dir']))
        encoder_weight = {}
        cl_ckpt = torch.load(args['weight_dir'], map_location='cpu')
        encoder_weight['weight_orig_weight'] = cl_ckpt['gcn.fc.weight']
        ori_state_dict = net_gcn.net_layer[0].state_dict()
        ori_state_dict.update(encoder_weight)
        net_gcn.net_layer[0].load_state_dict(ori_state_dict)

    if rewind_weight_mask:
        net_gcn.load_state_dict(rewind_weight_mask)
        if not args['rewind_soft_mask'] or args['init_soft_mask_type'] == 'all_one':
            pruning.soft_mask_init(net_gcn, args['init_soft_mask_type'], seed)
        adj_spar, wei_spar = pruning.print_sparsity(net_gcn)
    else:
        pruning.soft_mask_init(net_gcn, args['init_soft_mask_type'], seed)

    optimizer = torch.optim.Adam(net_gcn.parameters(), lr=args['lr'], weight_decay=args['weight_decay'])

    acc_test = 0.0
    best_val_acc = {'val_acc': 0, 'epoch' : 0, 'test_acc':0}
    rewind_weight = copy.deepcopy(net_gcn.state_dict())
    for epoch in range(args['total_epoch']):
        
        optimizer.zero_grad()
        output = net_gcn(features, adj)
        loss = loss_func(output, labels.float())
        loss.backward()
        pruning.subgradient_update_mask(net_gcn, args) # l1 norm
        optimizer.step()
        with torch.no_grad():
            output = net_gcn(features, adj, val_test=True)
            acc_val = f1_score(labels.cpu().numpy(), output.cpu().numpy().argmax(axis=1), average='micro')
            # acc_test = f1_score(labels[idx_test].cpu().numpy(), output[idx_test].cpu().numpy().argmax(axis=1), average='micro')
            if acc_val > best_val_acc['val_acc']:
                best_val_acc['test_acc'] = acc_test
                best_val_acc['val_acc'] = acc_val
                best_val_acc['epoch'] = epoch
                best_epoch_mask = pruning.get_final_mask_epoch(net_gcn, adj_percent=args['pruning_percent_adj'], 
                                                                        wei_percent=args['pruning_percent_wei'])

            print("(Get Mask) Epoch:[{}] Val:[{:.2f}] Test:[{:.2f}] | Best Val:[{:.2f}] Test:[{:.2f}] at Epoch:[{}]"
                 .format(epoch, acc_val * 100, acc_test * 100, 
                                best_val_acc['val_acc'] * 100,  
                                best_val_acc['test_acc'] * 100,
                                best_val_acc['epoch']))

    return best_epoch_mask, rewind_weight


In [5]:
rewind_weight = None
seed = 1
for p in range(20):
    
    final_mask_dict, rewind_weight = run_get_mask(args, seed, p, rewind_weight)

    rewind_weight['adj_mask1_train'] = final_mask_dict['adj_mask']
    rewind_weight['adj_mask2_fixed'] = final_mask_dict['adj_mask']
    rewind_weight['net_layer.0.weight_mask_train'] = final_mask_dict['weight1_mask']
    rewind_weight['net_layer.0.weight_mask_fixed'] = final_mask_dict['weight1_mask']
    rewind_weight['net_layer.1.weight_mask_train'] = final_mask_dict['weight2_mask']
    rewind_weight['net_layer.1.weight_mask_fixed'] = final_mask_dict['weight2_mask']

    best_acc_val, final_acc_test, final_epoch_list, adj_spar, wei_spar = run_fix_mask(args, seed, rewind_weight)
    print("=" * 120)
    print("syd : Sparsity:[{}], Best Val:[{:.2f}] at epoch:[{}] | Final Test Acc:[{:.2f}] Adj:[{:.2f}%] Wei:[{:.2f}%]"
        .format(p + 1, best_acc_val * 100, final_epoch_list, final_acc_test * 100, adj_spar, wei_spar))
    print("=" * 120)

torch.float32 torch.int64


RuntimeError: result type Float can't be cast to the desired output type Long