In [1]:
from scipy.sparse import data
import torch
import torch.nn as nn
import numpy as np
import scipy.sparse as sp
import scipy.io as sio
from sklearn.metrics import roc_auc_score
from datetime import datetime
import argparse

from model import Dominant

In [2]:
def normalize_adj(adj):
    """Symmetrically normalize adjacency matrix."""
    #co ordinate matrix 
    adj = sp.coo_matrix(adj)
    rowsum = np.array(adj.sum(1))
    d_inv_sqrt = np.power(rowsum, -0.5).flatten()
    d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.
    d_mat_inv_sqrt = sp.diags(d_inv_sqrt)
    return adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo()

def load_anomaly_detection_dataset(dataset, datadir='data'):
    
    data_mat = sio.loadmat(f'{datadir}/{dataset}.mat')
    adj = data_mat['Network']
    feat = data_mat['Attributes']
    truth = data_mat['Label']
    truth = truth.flatten()

    adj_norm = normalize_adj(adj + sp.eye(adj.shape[0]))
    adj_norm = adj_norm.toarray()
    adj = adj + sp.eye(adj.shape[0])
    adj = adj.toarray()
    feat = feat.toarray()
    return adj_norm, feat, truth, adj

In [3]:
import math

from torch.nn.modules.module import Module
from torch.nn.parameter import Parameter

class GraphConvolution(Module):
    """
    Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
    """

    def __init__(self, in_features, out_features, bias=True):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input, adj):
        support = torch.mm(input, self.weight)
        output = torch.spmm(adj, support)
        if self.bias is not None:
            return output + self.bias
        else:
            return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'

In [4]:
import torch.nn as nn
import torch.nn.functional as F
import torch

class Encoder(nn.Module):
    def __init__(self, nfeat, nhid, dropout):
        super(Encoder, self).__init__()

        self.gc1 = GraphConvolution(nfeat, nhid)
        self.gc2 = GraphConvolution(nhid, nhid)
        self.dropout = dropout

    def forward(self, x, adj):
        x = F.relu(self.gc1(x, adj))
        x = F.dropout(x, self.dropout, training=self.training)
        x = F.relu(self.gc2(x, adj))

        return x

class Attribute_Decoder(nn.Module):
    def __init__(self, nfeat, nhid, dropout):
        super(Attribute_Decoder, self).__init__()

        self.gc1 = GraphConvolution(nhid, nhid)
        self.gc2 = GraphConvolution(nhid, nfeat)
        self.dropout = dropout

    def forward(self, x, adj):
        x = F.relu(self.gc1(x, adj))
        x = F.dropout(x, self.dropout, training=self.training)
        x = F.relu(self.gc2(x, adj))

        return x

class Structure_Decoder(nn.Module):
    def __init__(self, nhid, dropout):
        super(Structure_Decoder, self).__init__()

        self.gc1 = GraphConvolution(nhid, nhid)
        self.dropout = dropout

    def forward(self, x, adj):
        x = F.relu(self.gc1(x, adj))
        x = F.dropout(x, self.dropout, training=self.training)
        x = x @ x.T        #this is matrix multiplication

        return x

class Dominant(nn.Module):
    def __init__(self, feat_size, hidden_size, dropout):
        super(Dominant, self).__init__()
        
        self.shared_encoder = Encoder(feat_size, hidden_size, dropout)
        self.attr_decoder = Attribute_Decoder(feat_size, hidden_size, dropout)
        self.struct_decoder = Structure_Decoder(hidden_size, dropout)
    
    def forward(self, x, adj):
        # encode
        x = self.shared_encoder(x, adj)
        # decode feature matrix
        x_hat = self.attr_decoder(x, adj)
        # decode adjacency matrix
        struct_reconstructed = self.struct_decoder(x, adj)
        # return reconstructed matrices
        return struct_reconstructed, x_hat

In [5]:
def loss_func(adj, A_hat, attrs, X_hat, alpha):
    # Attribute reconstruction loss
    diff_attribute = torch.pow(X_hat - attrs, 2)
    attribute_reconstruction_errors = torch.sqrt(torch.sum(diff_attribute, 1))
    attribute_cost = torch.mean(attribute_reconstruction_errors)

    # structure reconstruction loss
    diff_structure = torch.pow(A_hat - adj, 2)
    structure_reconstruction_errors = torch.sqrt(torch.sum(diff_structure, 1))
    structure_cost = torch.mean(structure_reconstruction_errors)


    cost =  alpha * attribute_reconstruction_errors + (1-alpha) * structure_reconstruction_errors

    return cost, structure_cost, attribute_cost

In [6]:
def train_dominant(dataset="BlogCatalog", hidden_dim=64, max_epoch=100, lr=5e-3, dropout=0.3, alpha=0.8, device="cpu"):
    adj, attrs, label, adj_label = load_anomaly_detection_dataset(dataset)
    adj = torch.FloatTensor(adj)
    adj_label = torch.FloatTensor(adj_label)
    attrs = torch.FloatTensor(attrs)
    
    model = Dominant(feat_size = attrs.size(1), hidden_size = hidden_dim, dropout = dropout)


    if device == 'cuda':
        device = torch.device(device)
        adj = adj.to(device)
        adj_label = adj_label.to(device)
        attrs = attrs.to(device)
        model = model.cuda()
        
    
    optimizer =  torch.optim.Adam(model.parameters(), lr=lr)
    
    for epoch in range(max_epoch):
        model.train()
        optimizer.zero_grad()
        # returns the reconstructed matrices
        A_hat, X_hat = model(attrs, adj)
        loss, struct_loss, feat_loss = loss_func(adj_label, A_hat, attrs, X_hat, alpha)
        l = torch.mean(loss)
        l.backward()
        optimizer.step()        
        print("Epoch:", '%04d' % (epoch), 
              "train_loss=", "{:.5f}".format(l.item()), 
              "train/struct_loss=", "{:.5f}".format(struct_loss.item()),
              "train/feat_loss=", "{:.5f}".format(feat_loss.item()))

        if epoch%10 == 0 or epoch == max_epoch - 1:
            model.eval()
            A_hat, X_hat = model(attrs, adj)
            loss, struct_loss, feat_loss = loss_func(adj_label, A_hat, attrs, X_hat, alpha)
            score = loss.detach().cpu().numpy()
            print("Epoch:", '%04d' % (epoch), 'Auc', roc_auc_score(label, score))

In [7]:
train_dominant(dataset="BlogCatalog", hidden_dim=64, max_epoch=100, lr=5e-3, dropout=0.3, alpha=0.8, device="cpu")

Epoch: 0000 train_loss= 4.48574 train/struct_loss= 17.38049 train/feat_loss= 1.26205
Epoch: 0000 Auc 0.8044805303356253
Epoch: 0001 train_loss= 3.20059 train/struct_loss= 11.16748 train/feat_loss= 1.20887
Epoch: 0002 train_loss= 2.67934 train/struct_loss= 8.69260 train/feat_loss= 1.17603
Epoch: 0003 train_loss= 2.51623 train/struct_loss= 7.88139 train/feat_loss= 1.17494
Epoch: 0004 train_loss= 2.47415 train/struct_loss= 7.67110 train/feat_loss= 1.17491
Epoch: 0005 train_loss= 2.47587 train/struct_loss= 7.68009 train/feat_loss= 1.17482
Epoch: 0006 train_loss= 2.47472 train/struct_loss= 7.67447 train/feat_loss= 1.17478
Epoch: 0007 train_loss= 2.46851 train/struct_loss= 7.64371 train/feat_loss= 1.17471
Epoch: 0008 train_loss= 2.47094 train/struct_loss= 7.65580 train/feat_loss= 1.17473
Epoch: 0009 train_loss= 2.47336 train/struct_loss= 7.66787 train/feat_loss= 1.17473
Epoch: 0010 train_loss= 2.47230 train/struct_loss= 7.66268 train/feat_loss= 1.17470
Epoch: 0010 Auc 0.8140317510776897
Epoc

Epoch: 0094 train_loss= 2.46595 train/struct_loss= 7.63103 train/feat_loss= 1.17468
Epoch: 0095 train_loss= 2.46584 train/struct_loss= 7.63047 train/feat_loss= 1.17468
Epoch: 0096 train_loss= 2.46606 train/struct_loss= 7.63158 train/feat_loss= 1.17468
Epoch: 0097 train_loss= 2.46597 train/struct_loss= 7.63113 train/feat_loss= 1.17468
Epoch: 0098 train_loss= 2.46608 train/struct_loss= 7.63166 train/feat_loss= 1.17468
Epoch: 0099 train_loss= 2.46588 train/struct_loss= 7.63067 train/feat_loss= 1.17468
Epoch: 0099 Auc 0.8141646638403293


In [13]:
train_dominant(dataset="BlogCatalog", hidden_dim=64, max_epoch=500, lr=5e-4, dropout=0.3, alpha=0.8, device="cpu")

Epoch: 0000 train_loss= 3.80194 train/struct_loss= 13.96719 train/feat_loss= 1.26063
Epoch: 0000 Auc 0.8076820836336431
Epoch: 0001 train_loss= 3.63766 train/struct_loss= 13.25442 train/feat_loss= 1.23347
Epoch: 0002 train_loss= 3.53396 train/struct_loss= 12.81713 train/feat_loss= 1.21317
Epoch: 0003 train_loss= 3.41595 train/struct_loss= 12.28494 train/feat_loss= 1.19870
Epoch: 0004 train_loss= 3.30913 train/struct_loss= 11.79079 train/feat_loss= 1.18872
Epoch: 0005 train_loss= 3.22164 train/struct_loss= 11.37924 train/feat_loss= 1.18224
Epoch: 0006 train_loss= 3.14188 train/struct_loss= 10.99608 train/feat_loss= 1.17834
Epoch: 0007 train_loss= 3.07212 train/struct_loss= 10.65597 train/feat_loss= 1.17615
Epoch: 0008 train_loss= 2.99504 train/struct_loss= 10.27493 train/feat_loss= 1.17507
Epoch: 0009 train_loss= 2.93523 train/struct_loss= 9.97784 train/feat_loss= 1.17458
Epoch: 0010 train_loss= 2.87493 train/struct_loss= 9.67710 train/feat_loss= 1.17438
Epoch: 0010 Auc 0.80831581716684

Epoch: 0094 train_loss= 2.46428 train/struct_loss= 7.62543 train/feat_loss= 1.17400
Epoch: 0095 train_loss= 2.46446 train/struct_loss= 7.62631 train/feat_loss= 1.17400
Epoch: 0096 train_loss= 2.46435 train/struct_loss= 7.62578 train/feat_loss= 1.17399
Epoch: 0097 train_loss= 2.46421 train/struct_loss= 7.62509 train/feat_loss= 1.17399
Epoch: 0098 train_loss= 2.46424 train/struct_loss= 7.62525 train/feat_loss= 1.17399
Epoch: 0099 train_loss= 2.46420 train/struct_loss= 7.62506 train/feat_loss= 1.17399
Epoch: 0100 train_loss= 2.46437 train/struct_loss= 7.62591 train/feat_loss= 1.17399
Epoch: 0100 Auc 0.81424310977498
Epoch: 0101 train_loss= 2.46432 train/struct_loss= 7.62566 train/feat_loss= 1.17399
Epoch: 0102 train_loss= 2.46430 train/struct_loss= 7.62555 train/feat_loss= 1.17399
Epoch: 0103 train_loss= 2.46438 train/struct_loss= 7.62595 train/feat_loss= 1.17399
Epoch: 0104 train_loss= 2.46415 train/struct_loss= 7.62481 train/feat_loss= 1.17399
Epoch: 0105 train_loss= 2.46423 train/struc

Epoch: 0188 train_loss= 2.46348 train/struct_loss= 7.62169 train/feat_loss= 1.17393
Epoch: 0189 train_loss= 2.46333 train/struct_loss= 7.62091 train/feat_loss= 1.17393
Epoch: 0190 train_loss= 2.46346 train/struct_loss= 7.62160 train/feat_loss= 1.17393
Epoch: 0190 Auc 0.8143475901682923
Epoch: 0191 train_loss= 2.46331 train/struct_loss= 7.62085 train/feat_loss= 1.17393
Epoch: 0192 train_loss= 2.46333 train/struct_loss= 7.62093 train/feat_loss= 1.17393
Epoch: 0193 train_loss= 2.46315 train/struct_loss= 7.62006 train/feat_loss= 1.17393
Epoch: 0194 train_loss= 2.46318 train/struct_loss= 7.62020 train/feat_loss= 1.17393
Epoch: 0195 train_loss= 2.46322 train/struct_loss= 7.62041 train/feat_loss= 1.17393
Epoch: 0196 train_loss= 2.46329 train/struct_loss= 7.62076 train/feat_loss= 1.17393
Epoch: 0197 train_loss= 2.46322 train/struct_loss= 7.62041 train/feat_loss= 1.17393
Epoch: 0198 train_loss= 2.46355 train/struct_loss= 7.62207 train/feat_loss= 1.17393
Epoch: 0199 train_loss= 2.46353 train/str

Epoch: 0282 train_loss= 2.46259 train/struct_loss= 7.61745 train/feat_loss= 1.17388
Epoch: 0283 train_loss= 2.46259 train/struct_loss= 7.61744 train/feat_loss= 1.17387
Epoch: 0284 train_loss= 2.46251 train/struct_loss= 7.61707 train/feat_loss= 1.17387
Epoch: 0285 train_loss= 2.46260 train/struct_loss= 7.61750 train/feat_loss= 1.17387
Epoch: 0286 train_loss= 2.46252 train/struct_loss= 7.61713 train/feat_loss= 1.17387
Epoch: 0287 train_loss= 2.46256 train/struct_loss= 7.61729 train/feat_loss= 1.17387
Epoch: 0288 train_loss= 2.46261 train/struct_loss= 7.61758 train/feat_loss= 1.17387
Epoch: 0289 train_loss= 2.46255 train/struct_loss= 7.61726 train/feat_loss= 1.17387
Epoch: 0290 train_loss= 2.46237 train/struct_loss= 7.61640 train/feat_loss= 1.17387
Epoch: 0290 Auc 0.8145113332109257
Epoch: 0291 train_loss= 2.46237 train/struct_loss= 7.61639 train/feat_loss= 1.17387
Epoch: 0292 train_loss= 2.46248 train/struct_loss= 7.61695 train/feat_loss= 1.17387
Epoch: 0293 train_loss= 2.46237 train/str

Epoch: 0376 train_loss= 2.46198 train/struct_loss= 7.61477 train/feat_loss= 1.17379
Epoch: 0377 train_loss= 2.46193 train/struct_loss= 7.61450 train/feat_loss= 1.17379
Epoch: 0378 train_loss= 2.46205 train/struct_loss= 7.61510 train/feat_loss= 1.17379
Epoch: 0379 train_loss= 2.46192 train/struct_loss= 7.61446 train/feat_loss= 1.17378
Epoch: 0380 train_loss= 2.46209 train/struct_loss= 7.61534 train/feat_loss= 1.17378
Epoch: 0380 Auc 0.8146168412802377
Epoch: 0381 train_loss= 2.46189 train/struct_loss= 7.61433 train/feat_loss= 1.17378
Epoch: 0382 train_loss= 2.46211 train/struct_loss= 7.61544 train/feat_loss= 1.17378
Epoch: 0383 train_loss= 2.46219 train/struct_loss= 7.61583 train/feat_loss= 1.17378
Epoch: 0384 train_loss= 2.46219 train/struct_loss= 7.61582 train/feat_loss= 1.17378
Epoch: 0385 train_loss= 2.46207 train/struct_loss= 7.61522 train/feat_loss= 1.17378
Epoch: 0386 train_loss= 2.46199 train/struct_loss= 7.61483 train/feat_loss= 1.17378
Epoch: 0387 train_loss= 2.46194 train/str

Epoch: 0470 train_loss= 2.46149 train/struct_loss= 7.61271 train/feat_loss= 1.17368
Epoch: 0470 Auc 0.8146600036722289
Epoch: 0471 train_loss= 2.46166 train/struct_loss= 7.61359 train/feat_loss= 1.17368
Epoch: 0472 train_loss= 2.46147 train/struct_loss= 7.61262 train/feat_loss= 1.17368
Epoch: 0473 train_loss= 2.46149 train/struct_loss= 7.61276 train/feat_loss= 1.17368
Epoch: 0474 train_loss= 2.46156 train/struct_loss= 7.61307 train/feat_loss= 1.17368
Epoch: 0475 train_loss= 2.46140 train/struct_loss= 7.61232 train/feat_loss= 1.17368
Epoch: 0476 train_loss= 2.46159 train/struct_loss= 7.61321 train/feat_loss= 1.17368
Epoch: 0477 train_loss= 2.46152 train/struct_loss= 7.61292 train/feat_loss= 1.17368
Epoch: 0478 train_loss= 2.46150 train/struct_loss= 7.61282 train/feat_loss= 1.17367
Epoch: 0479 train_loss= 2.46142 train/struct_loss= 7.61237 train/feat_loss= 1.17368
Epoch: 0480 train_loss= 2.46165 train/struct_loss= 7.61354 train/feat_loss= 1.17368
Epoch: 0480 Auc 0.8146702804322268
Epoch:

In [19]:
train_dominant(dataset="BlogCatalog", hidden_dim=64, max_epoch=500, lr=5e-2, dropout=0.3, alpha=0.8, device="mps")

Epoch: 0000 train_loss= 4.34323 train/struct_loss= 16.68764 train/feat_loss= 1.25712
Epoch: 0000 Auc 0.6479319048180192
Epoch: 0001 train_loss= 14.08281 train/struct_loss= 52.20726 train/feat_loss= 4.55171
Epoch: 0002 train_loss= 2.67803 train/struct_loss= 7.71949 train/feat_loss= 1.41767
Epoch: 0003 train_loss= 2.50623 train/struct_loss= 7.68096 train/feat_loss= 1.21255
Epoch: 0004 train_loss= 2.47950 train/struct_loss= 7.67807 train/feat_loss= 1.17986
Epoch: 0005 train_loss= 2.51940 train/struct_loss= 7.88895 train/feat_loss= 1.17702
Epoch: 0006 train_loss= 2.47858 train/struct_loss= 7.68228 train/feat_loss= 1.17765
Epoch: 0007 train_loss= 2.48236 train/struct_loss= 7.71183 train/feat_loss= 1.17500
Epoch: 0008 train_loss= 2.47954 train/struct_loss= 7.69773 train/feat_loss= 1.17500
Epoch: 0009 train_loss= 2.47969 train/struct_loss= 7.69844 train/feat_loss= 1.17500
Epoch: 0010 train_loss= 2.47787 train/struct_loss= 7.68937 train/feat_loss= 1.17500
Epoch: 0010 Auc 0.8133027862351705
Epo

Epoch: 0094 train_loss= 2.47205 train/struct_loss= 7.66024 train/feat_loss= 1.17500
Epoch: 0095 train_loss= 2.47147 train/struct_loss= 7.65734 train/feat_loss= 1.17500
Epoch: 0096 train_loss= 2.47184 train/struct_loss= 7.65922 train/feat_loss= 1.17500
Epoch: 0097 train_loss= 2.47201 train/struct_loss= 7.66006 train/feat_loss= 1.17500
Epoch: 0098 train_loss= 2.47163 train/struct_loss= 7.65815 train/feat_loss= 1.17500
Epoch: 0099 train_loss= 2.47213 train/struct_loss= 7.66067 train/feat_loss= 1.17500
Epoch: 0100 train_loss= 2.47139 train/struct_loss= 7.65697 train/feat_loss= 1.17500
Epoch: 0100 Auc 0.8138474545150602
Epoch: 0101 train_loss= 2.47151 train/struct_loss= 7.65755 train/feat_loss= 1.17500
Epoch: 0102 train_loss= 2.47147 train/struct_loss= 7.65737 train/feat_loss= 1.17500
Epoch: 0103 train_loss= 2.47169 train/struct_loss= 7.65843 train/feat_loss= 1.17500
Epoch: 0104 train_loss= 2.47185 train/struct_loss= 7.65927 train/feat_loss= 1.17500
Epoch: 0105 train_loss= 2.47138 train/str

Epoch: 0188 train_loss= 2.47217 train/struct_loss= 7.66084 train/feat_loss= 1.17500
Epoch: 0189 train_loss= 2.47180 train/struct_loss= 7.65902 train/feat_loss= 1.17500
Epoch: 0190 train_loss= 2.47183 train/struct_loss= 7.65915 train/feat_loss= 1.17500
Epoch: 0190 Auc 0.8138597866270577
Epoch: 0191 train_loss= 2.47165 train/struct_loss= 7.65827 train/feat_loss= 1.17500
Epoch: 0192 train_loss= 2.47161 train/struct_loss= 7.65805 train/feat_loss= 1.17500
Epoch: 0193 train_loss= 2.47175 train/struct_loss= 7.65877 train/feat_loss= 1.17500
Epoch: 0194 train_loss= 2.47164 train/struct_loss= 7.65820 train/feat_loss= 1.17500
Epoch: 0195 train_loss= 2.47195 train/struct_loss= 7.65976 train/feat_loss= 1.17500
Epoch: 0196 train_loss= 2.47151 train/struct_loss= 7.65757 train/feat_loss= 1.17500
Epoch: 0197 train_loss= 2.47179 train/struct_loss= 7.65893 train/feat_loss= 1.17500
Epoch: 0198 train_loss= 2.47197 train/struct_loss= 7.65988 train/feat_loss= 1.17500
Epoch: 0199 train_loss= 2.47204 train/str

Epoch: 0282 train_loss= 2.47213 train/struct_loss= 7.66068 train/feat_loss= 1.17500
Epoch: 0283 train_loss= 2.47224 train/struct_loss= 7.66119 train/feat_loss= 1.17500
Epoch: 0284 train_loss= 2.47227 train/struct_loss= 7.66137 train/feat_loss= 1.17500
Epoch: 0285 train_loss= 2.47183 train/struct_loss= 7.65918 train/feat_loss= 1.17500
Epoch: 0286 train_loss= 2.47224 train/struct_loss= 7.66123 train/feat_loss= 1.17500
Epoch: 0287 train_loss= 2.47175 train/struct_loss= 7.65877 train/feat_loss= 1.17500
Epoch: 0288 train_loss= 2.47148 train/struct_loss= 7.65742 train/feat_loss= 1.17500
Epoch: 0289 train_loss= 2.47158 train/struct_loss= 7.65793 train/feat_loss= 1.17500
Epoch: 0290 train_loss= 2.47206 train/struct_loss= 7.66031 train/feat_loss= 1.17500
Epoch: 0290 Auc 0.8138412884590616
Epoch: 0291 train_loss= 2.47155 train/struct_loss= 7.65777 train/feat_loss= 1.17500
Epoch: 0292 train_loss= 2.47176 train/struct_loss= 7.65882 train/feat_loss= 1.17500
Epoch: 0293 train_loss= 2.47141 train/str

Epoch: 0376 train_loss= 2.47150 train/struct_loss= 7.65752 train/feat_loss= 1.17500
Epoch: 0377 train_loss= 2.47175 train/struct_loss= 7.65876 train/feat_loss= 1.17500
Epoch: 0378 train_loss= 2.47133 train/struct_loss= 7.65664 train/feat_loss= 1.17500
Epoch: 0379 train_loss= 2.47205 train/struct_loss= 7.66025 train/feat_loss= 1.17500
Epoch: 0380 train_loss= 2.47189 train/struct_loss= 7.65947 train/feat_loss= 1.17500
Epoch: 0380 Auc 0.8138522503363925
Epoch: 0381 train_loss= 2.47227 train/struct_loss= 7.66134 train/feat_loss= 1.17500
Epoch: 0382 train_loss= 2.47151 train/struct_loss= 7.65757 train/feat_loss= 1.17500
Epoch: 0383 train_loss= 2.47147 train/struct_loss= 7.65736 train/feat_loss= 1.17500
Epoch: 0384 train_loss= 2.47191 train/struct_loss= 7.65957 train/feat_loss= 1.17500
Epoch: 0385 train_loss= 2.47161 train/struct_loss= 7.65805 train/feat_loss= 1.17500
Epoch: 0386 train_loss= 2.47198 train/struct_loss= 7.65989 train/feat_loss= 1.17500
Epoch: 0387 train_loss= 2.47192 train/str

Epoch: 0470 train_loss= 2.47165 train/struct_loss= 7.65825 train/feat_loss= 1.17500
Epoch: 0470 Auc 0.81378305348574
Epoch: 0471 train_loss= 2.47162 train/struct_loss= 7.65810 train/feat_loss= 1.17500
Epoch: 0472 train_loss= 2.47139 train/struct_loss= 7.65695 train/feat_loss= 1.17500
Epoch: 0473 train_loss= 2.47203 train/struct_loss= 7.66014 train/feat_loss= 1.17500
Epoch: 0474 train_loss= 2.47172 train/struct_loss= 7.65860 train/feat_loss= 1.17500
Epoch: 0475 train_loss= 2.47186 train/struct_loss= 7.65932 train/feat_loss= 1.17500
Epoch: 0476 train_loss= 2.47227 train/struct_loss= 7.66135 train/feat_loss= 1.17500
Epoch: 0477 train_loss= 2.47215 train/struct_loss= 7.66078 train/feat_loss= 1.17500
Epoch: 0478 train_loss= 2.47213 train/struct_loss= 7.66068 train/feat_loss= 1.17500
Epoch: 0479 train_loss= 2.47181 train/struct_loss= 7.65906 train/feat_loss= 1.17500
Epoch: 0480 train_loss= 2.47212 train/struct_loss= 7.66062 train/feat_loss= 1.17500
Epoch: 0480 Auc 0.8138214200563989
Epoch: 0

In [9]:
train_dominant(dataset="ACM", hidden_dim=64, max_epoch=100, lr=5e-3, dropout=0.3, alpha=0.8, device="mps")

Epoch: 0000 train_loss= 5.35934 train/struct_loss= 23.50783 train/feat_loss= 0.82222
Epoch: 0000 Auc 0.8080926231628126
Epoch: 0001 train_loss= 3.13666 train/struct_loss= 12.82364 train/feat_loss= 0.71492
Epoch: 0002 train_loss= 1.81711 train/struct_loss= 6.31251 train/feat_loss= 0.69326
Epoch: 0003 train_loss= 1.30341 train/struct_loss= 3.74537 train/feat_loss= 0.69292
Epoch: 0004 train_loss= 1.24686 train/struct_loss= 3.46262 train/feat_loss= 0.69292
Epoch: 0005 train_loss= 1.24628 train/struct_loss= 3.45979 train/feat_loss= 0.69290
Epoch: 0006 train_loss= 1.24634 train/struct_loss= 3.46011 train/feat_loss= 0.69290
Epoch: 0007 train_loss= 1.24630 train/struct_loss= 3.45987 train/feat_loss= 0.69290
Epoch: 0008 train_loss= 1.24670 train/struct_loss= 3.46184 train/feat_loss= 0.69292
Epoch: 0009 train_loss= 1.24629 train/struct_loss= 3.45984 train/feat_loss= 0.69290
Epoch: 0010 train_loss= 1.24640 train/struct_loss= 3.46042 train/feat_loss= 0.69290
Epoch: 0010 Auc 0.896741423067584
Epoch

Epoch: 0094 train_loss= 1.24650 train/struct_loss= 3.46089 train/feat_loss= 0.69290
Epoch: 0095 train_loss= 1.24650 train/struct_loss= 3.46089 train/feat_loss= 0.69290
Epoch: 0096 train_loss= 1.24650 train/struct_loss= 3.46089 train/feat_loss= 0.69290
Epoch: 0097 train_loss= 1.24650 train/struct_loss= 3.46089 train/feat_loss= 0.69290
Epoch: 0098 train_loss= 1.24650 train/struct_loss= 3.46089 train/feat_loss= 0.69290
Epoch: 0099 train_loss= 1.24650 train/struct_loss= 3.46089 train/feat_loss= 0.69290
Epoch: 0099 Auc 0.8967425301324609


In [10]:
train_dominant(dataset="Flickr", hidden_dim=64, max_epoch=100, lr=5e-3, dropout=0.3, alpha=0.8, device="mps")

Epoch: 0000 train_loss= 5.35012 train/struct_loss= 21.66717 train/feat_loss= 1.27086
Epoch: 0000 Auc 0.7789570890524292
Epoch: 0001 train_loss= 3.87508 train/struct_loss= 14.59658 train/feat_loss= 1.19470
Epoch: 0002 train_loss= 3.04225 train/struct_loss= 10.48366 train/feat_loss= 1.18189
Epoch: 0003 train_loss= 2.56934 train/struct_loss= 8.12093 train/feat_loss= 1.18144
Epoch: 0004 train_loss= 2.33280 train/struct_loss= 6.93812 train/feat_loss= 1.18147
Epoch: 0005 train_loss= 2.24299 train/struct_loss= 6.48912 train/feat_loss= 1.18145
Epoch: 0006 train_loss= 2.21304 train/struct_loss= 6.33940 train/feat_loss= 1.18145
Epoch: 0007 train_loss= 2.23252 train/struct_loss= 6.43676 train/feat_loss= 1.18146
Epoch: 0008 train_loss= 2.21936 train/struct_loss= 6.37107 train/feat_loss= 1.18143
Epoch: 0009 train_loss= 2.20474 train/struct_loss= 6.29794 train/feat_loss= 1.18145
Epoch: 0010 train_loss= 2.20720 train/struct_loss= 6.31020 train/feat_loss= 1.18146
Epoch: 0010 Auc 0.7940296578785636
Epo

Epoch: 0094 train_loss= 2.20739 train/struct_loss= 6.31122 train/feat_loss= 1.18143
Epoch: 0095 train_loss= 2.20686 train/struct_loss= 6.30859 train/feat_loss= 1.18143
Epoch: 0096 train_loss= 2.20857 train/struct_loss= 6.31716 train/feat_loss= 1.18143
Epoch: 0097 train_loss= 2.20833 train/struct_loss= 6.31597 train/feat_loss= 1.18143
Epoch: 0098 train_loss= 2.20804 train/struct_loss= 6.31451 train/feat_loss= 1.18143
Epoch: 0099 train_loss= 2.20702 train/struct_loss= 6.30939 train/feat_loss= 1.18143
Epoch: 0099 Auc 0.7941618733945821


In [9]:

data_mat = sio.loadmat(f'data/Amazon.mat')

In [10]:
data_mat

{'__header__': b'MATLAB 5.0 MAT-file, Platform: PCWIN64, Created on: Mon Aug 17 19:18:11 2015',
 '__version__': '1.0',
 '__globals__': [],
 'gnd': array([[0],
        [0],
        [0],
        ...,
        [0],
        [0],
        [0]], dtype=uint8),
 'X': array([[2.06610e-02, 3.33333e-01, 1.90000e-05, ..., 1.30400e-03,
         0.00000e+00, 0.00000e+00],
        [0.00000e+00, 1.22449e-01, 1.22000e-04, ..., 4.95500e-03,
         3.06122e-01, 2.24490e-01],
        [0.00000e+00, 0.00000e+00, 1.36000e-04, ..., 5.22000e-04,
         1.00000e+00, 0.00000e+00],
        ...,
        [0.00000e+00, 0.00000e+00, 1.00000e-05, ..., 2.61000e-04,
         1.00000e+00, 0.00000e+00],
        [0.00000e+00, 2.00000e-01, 1.00000e-05, ..., 1.04300e-03,
         4.00000e-01, 0.00000e+00],
        [1.23970e-02, 4.54550e-02, 1.00000e-05, ..., 6.25900e-03,
         6.81818e-01, 0.00000e+00]]),
 'A': array([[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,

In [40]:
adj = data_mat['Network']
feat = data_mat['Attributes']
truth = data_mat['Label']
truth = truth.flatten()

array([0, 0, 0, ..., 0, 0, 0], dtype=uint8)

In [41]:
np.unique(truth, return_counts=True)

(array([0, 1], dtype=uint8), array([4898,  298]))

In [43]:
adj_norm = normalize_adj(adj + sp.eye(adj.shape[0]))


<5196x5196 sparse matrix of type '<class 'numpy.float64'>'
	with 350577 stored elements in COOrdinate format>

In [46]:
adj_norm.toarray()

array([[0.0012987 , 0.00156096, 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.00156096, 0.00187617, 0.00195477, ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.00195477, 0.00203666, ..., 0.        , 0.        ,
        0.        ],
       ...,
       [0.        , 0.        , 0.        , ..., 0.0625    , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.05263158,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.05263158]])

In [47]:
adj_norm = normalize_adj(adj + sp.eye(adj.shape[0]))
adj_norm = adj_norm.toarray()
adj = adj + sp.eye(adj.shape[0])
adj = adj.toarray()
feat = feat.toarray()

In [49]:
adj_norm

array([[0.0012987 , 0.00156096, 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.00156096, 0.00187617, 0.00195477, ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.00195477, 0.00203666, ..., 0.        , 0.        ,
        0.        ],
       ...,
       [0.        , 0.        , 0.        , ..., 0.0625    , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.05263158,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.05263158]])

In [67]:
adj_norm[0][2]

0.0