In [1]:
import networkx as nx
import pandas as pd
import numpy as np
from utils import load_data,normalize,doublerelu
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
torch.set_printoptions(sci_mode=False)
import time

In [2]:
import warnings
warnings.filterwarnings("ignore") 

In [3]:
# Run configuration shared by the training cells below.
cuda = torch.cuda.is_available()   # move model/tensors to the GPU when one is visible
lr = 0.0001                        # Adam learning rate
# NOTE(review): the original literal was `10e-4`, which equals 1e-3 (not 1e-4);
# confirm which value was intended.
weight_decay = 1e-3                # Adam weight decay
epochs, seed, hidden = 10001, 165, 10

In [4]:
# Seed every RNG the notebook draws from (numpy masking, torch init), in the
# same order as before: numpy, torch CPU, then torch CUDA when a GPU is used.
seeders = [np.random.seed, torch.manual_seed]
if cuda:
    seeders.append(torch.cuda.manual_seed)
for seeder in seeders:
    seeder(seed)

In [5]:
class GNN1Layer(Module):
    """Batched graph-convolution step with one weight pair per batch entry.

    Computes ``input @ W1 + (adj @ input) @ W2`` using batched matrix
    products. ``W1`` starts as a (possibly rectangular) identity and ``W2`` as
    zeros, so a freshly built layer maps ``input`` to ``input @ eye(in, out)``.

    Args:
        batch_size: leading dimension of both weight tensors (must match the
            leading dimension of ``input`` and ``adj`` for ``torch.bmm``).
        in_features: size of the last dimension of ``input``.
        out_features: size of the last dimension of the output.
    """

    def __init__(self, batch_size, in_features, out_features):
        super(GNN1Layer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.batch_size = batch_size

        # torch.eye can produce float32 directly; the legacy torch.FloatTensor(...)
        # wrapper was redundant. repeat() gives each batch entry its own copy.
        weight1_eye = torch.eye(in_features, out_features, dtype=torch.float32)
        self.weight1 = Parameter(weight1_eye.unsqueeze(0).repeat(batch_size, 1, 1))
        self.weight2 = Parameter(torch.zeros(batch_size, in_features, out_features))

    def forward(self, input, adj):
        """Return ``input @ W1 + (adj @ input) @ W2``.

        Args:
            input: tensor of shape (batch_size, n, in_features).
            adj: tensor of shape (batch_size, n, n).

        Returns:
            Tensor of shape (batch_size, n, out_features).
        """
        self_term = torch.bmm(input, self.weight1)
        neighbour_term = torch.bmm(torch.bmm(adj, input), self.weight2)
        return self_term + neighbour_term

In [6]:
class GNN2Layer(Module):
    """Batched graph-convolution layer used as the first layer of ``GNN2``.

    Computes ``input @ W1 + (adj @ input) @ W2`` using batched matrix
    products. ``W1`` starts as a (possibly rectangular) identity and ``W2`` as
    zeros.

    Args:
        batch_size: leading dimension of both weight tensors.
        in_features: size of the last dimension of ``input``.
        out_features: size of the last dimension of the output.
    """

    def __init__(self, batch_size, in_features, out_features):
        super(GNN2Layer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.batch_size = batch_size

        # The original also created a Xavier-initialised tensor of shape
        # (batch, in, out - in) and never used it. That dead code raised a
        # negative-dimension error whenever out_features < in_features, and it
        # consumed torch RNG draws, so it has been removed.
        # NOTE(review): the intent may have been to fill W1's extra columns with
        # those random values instead of zeros — confirm before reinstating.
        weight1_eye = torch.eye(in_features, out_features, dtype=torch.float32)
        self.weight1 = Parameter(weight1_eye.unsqueeze(0).repeat(batch_size, 1, 1))
        self.weight2 = Parameter(torch.zeros(batch_size, in_features, out_features))

    def forward(self, input, adj):
        """Return ``input @ W1 + (adj @ input) @ W2``.

        Args:
            input: tensor of shape (batch_size, n, in_features).
            adj: tensor of shape (batch_size, n, n).

        Returns:
            Tensor of shape (batch_size, n, out_features).
        """
        v1 = torch.bmm(input, self.weight1)
        v2 = torch.bmm(torch.bmm(adj, input), self.weight2)
        return v1 + v2

In [7]:
class GNN1(nn.Module):
    """Single-layer model that re-predicts only the masked rows of its input.

    Args:
        batch_size: forwarded to ``GNN1Layer``.
        nfeat: accepted for interface compatibility; not used by this model.
        ndim: width of the last (class) dimension of input and output.
        hidden: accepted for interface compatibility; not used by this model.
    """

    def __init__(self, batch_size, nfeat, ndim, hidden):
        super().__init__()
        self.gc1 = GNN1Layer(batch_size, ndim, ndim)

    def forward(self, x, adj, random_indices):
        """Return a copy of ``x`` whose ``random_indices`` rows are re-predicted.

        Only batch entry 0 is patched; every other value is copied from ``x``.
        """
        result = torch.clone(x)
        pred = doublerelu(self.gc1(x, adj))
        # Rescale each row so its last-dim entries sum to 1.
        # NOTE(review): a row that sums to 0 after doublerelu would divide by zero.
        pred = pred / pred.sum(dim=2, keepdim=True)
        result[0][random_indices, :] = pred[0][random_indices, :]
        return result

In [8]:
class GNN2(nn.Module):
    """Two-layer model mapping ``ndim -> hidden -> ndim`` with row-normalised output.

    Args:
        batch_size: forwarded to both layers.
        nfeat: accepted for interface compatibility; not used by this model.
        ndim: width of the input and output last dimension.
        hidden: width of the intermediate representation.
    """

    def __init__(self, batch_size, nfeat, ndim, hidden):
        super().__init__()
        self.gc1 = GNN2Layer(batch_size, ndim, hidden)
        self.gc2 = GNN1Layer(batch_size, hidden, ndim)

    def forward(self, x, adj):
        """Apply both layers (each followed by ``doublerelu``), then rescale rows to sum to 1."""
        h = x
        for layer in (self.gc1, self.gc2):
            h = doublerelu(layer(h, adj))
        return h / h.sum(dim=2, keepdim=True)

In [9]:
def train(adj,features,labels,random_indices):
    """Fit a fresh ``GNN1`` on one graph and return its best loss and output.

    Reads the notebook-level config globals ``cuda``, ``hidden``, ``lr``,
    ``weight_decay`` and ``epochs``.

    Args:
        adj: adjacency; passed through ``normalize`` for the model input, and
            its first two dimensions give ``GNN1``'s ``batch_size``/``nfeat``.
        features: array shaped like the one-hot features built in the data
            cell, i.e. (1, n, n_classes); the last dimension sets the model width.
        labels: class ids; shifted down by 1 here before use as
            ``CrossEntropyLoss`` targets.
        random_indices: row indices (in batch 0) that the model re-predicts.

    Returns:
        ``(best_loss, best_output)`` from the epoch with the lowest loss, both
        detached from the autograd graph (``(None, None)`` if ``epochs == 0``).
    """
    adj_norm = normalize(adj)

    labels = labels - 1

    adj = torch.FloatTensor(adj)
    adj_norm = torch.FloatTensor(adj_norm)
    features = torch.FloatTensor(features)
    labels = torch.FloatTensor(labels)

    # Take the class count from the features themselves rather than the global
    # `nb_label`, which is only defined in a later cell.
    n_classes = features.shape[-1]
    model = GNN1(batch_size=adj.shape[0],
                 nfeat=adj.shape[1],
                 ndim=n_classes,
                 hidden=hidden)
    if cuda:
        # The raw `adj` is only needed for its shape, so it stays on the CPU.
        model.cuda()
        features = features.cuda()
        adj_norm = adj_norm.cuda()
        labels = labels.cuda()

    t_total = time.time()

    optimizer = optim.Adam(model.parameters(),
                           lr=lr, weight_decay=weight_decay)

    # NOTE(review): GNN1 returns rows rescaled to sum to 1, but CrossEntropyLoss
    # applies log_softmax to its input as if it were raw logits. NLLLoss on
    # log(output) may be what was intended — confirm before changing, since it
    # alters the loss values compared across rounds by the caller.
    criterion = nn.CrossEntropyLoss()

    # Loop-invariant views of the labels.
    labels_row = labels.reshape(1, -1)
    targets = labels.reshape(-1).long()
    n_labels = labels.shape[0]

    best_loss = best_output = best_acc = None
    for epoch in range(epochs):

        t = time.time()
        model.train()
        optimizer.zero_grad()

        output = model(features, adj_norm, random_indices)

        accuracy = torch.sum(torch.argmax(output, axis=2) == labels_row) / n_labels

        loss = criterion(output[0], targets)

        # A new graph is built every epoch, so there is nothing to retain.
        loss.backward()

        optimizer.step()

        # Keep detached snapshots so the best epoch does not pin its autograd
        # graph (and every intermediate tensor) in memory.
        if best_loss is None or loss < best_loss:
            best_loss = loss.detach()
            best_output = output.detach()
            best_acc = accuracy

        if epoch % 1000 == 0:
            print('Epoch: {:04d}'.format(epoch + 1),
                  'Accuracy: {:.4f}'.format(best_acc.item()),
                  'Loss: {:.8f}'.format(best_loss.item()),
                  'time: {:.4f}s'.format(time.time() - t))

    print("Optimization Finished!")
    print("Total time elapsed: {:.4f}s".format(time.time() - t_total))

    return best_loss,best_output

In [10]:
adj, feature, labels = load_data()

# Make the raw feature ids zero-based; the largest id then fixes the class count.
feature = feature - 1
nb_label = 1 + int(max(feature))

# One-hot encode by indexing an identity matrix with a (1, K) id array,
# giving an array of shape (1, K, nb_label).
feature_ids = np.array(feature, dtype=int).reshape(1, -1)
featuress = np.eye(nb_label)[feature_ids]

In [None]:
mask_percentage = [0.3, 0.5, 0.7]
# Best (loss, output) reached for each mask fraction, so earlier fractions'
# results are not lost when the loop moves on.
mask_results = {}
for m in mask_percentage:

    features = np.copy(featuress)
    # Mask a random m-fraction of the rows in batch 0 by overwriting them with 0.2.
    # NOTE(review): 0.2 looks like a uniform 1/nb_label prior for nb_label == 5 — confirm.
    number_of_rows = features[0].shape[0]
    random_indices = np.random.choice(number_of_rows, size=int(m * number_of_rows), replace=False)
    features[0][random_indices, :] = 0.2

    print("\nMasked {}% of nodes\n".format(int(m*100)))
    prev_loss, op = train(adj, features, labels, random_indices)
    # Re-train from the previous best output until the loss stops improving.
    # A new output is adopted only if it improved, so `op` ends as the best one
    # (the original loop left `op` pointing at the final, non-improving run).
    while True:
        loss, new_op = train(adj, op.cpu().detach().numpy(), labels, random_indices)
        if not loss < prev_loss:
            break
        prev_loss, op = loss, new_op
    mask_results[m] = (prev_loss, op)


Masked 30% of nodes

Epoch: 0001 Accuracy: 0.6625 Loss: 1.19313920 time: 0.3581s
Epoch: 1001 Accuracy: 0.7339 Loss: 1.16949165 time: 0.0030s
Epoch: 2001 Accuracy: 0.7339 Loss: 1.15972137 time: 0.0020s
Epoch: 3001 Accuracy: 0.7334 Loss: 1.15877306 time: 0.0020s
Epoch: 4001 Accuracy: 0.7353 Loss: 1.15721190 time: 0.0020s
Epoch: 5001 Accuracy: 0.7473 Loss: 1.15168619 time: 0.0020s
Epoch: 6001 Accuracy: 0.7663 Loss: 1.14059055 time: 0.0010s
Epoch: 7001 Accuracy: 0.7663 Loss: 1.14059055 time: 0.0020s
Epoch: 8001 Accuracy: 0.7663 Loss: 1.13737237 time: 0.0020s
Epoch: 9001 Accuracy: 0.7659 Loss: 1.13683629 time: 0.0020s
Epoch: 10001 Accuracy: 0.7659 Loss: 1.13683629 time: 0.0020s
Optimization Finished!
Total time elapsed: 19.3150s
Epoch: 0001 Accuracy: 0.7659 Loss: 1.13683629 time: 0.0030s
Epoch: 1001 Accuracy: 0.7650 Loss: 1.13554573 time: 0.0020s
Epoch: 2001 Accuracy: 0.7673 Loss: 1.13509429 time: 0.0020s
Epoch: 3001 Accuracy: 0.7673 Loss: 1.13455796 time: 0.0010s
Epoch: 4001 Accuracy: 0.7

Epoch: 6001 Accuracy: 0.7775 Loss: 1.12537718 time: 0.0010s
Epoch: 7001 Accuracy: 0.7775 Loss: 1.12537718 time: 0.0020s
Epoch: 8001 Accuracy: 0.7775 Loss: 1.12537718 time: 0.0020s
Epoch: 9001 Accuracy: 0.7775 Loss: 1.12537718 time: 0.0020s
Epoch: 10001 Accuracy: 0.7775 Loss: 1.12537718 time: 0.0020s
Optimization Finished!
Total time elapsed: 18.6304s
Epoch: 0001 Accuracy: 0.7775 Loss: 1.12537718 time: 0.0020s
Epoch: 1001 Accuracy: 0.7784 Loss: 1.12510467 time: 0.0010s
Epoch: 2001 Accuracy: 0.7784 Loss: 1.12510467 time: 0.0020s
Epoch: 3001 Accuracy: 0.7784 Loss: 1.12510467 time: 0.0020s
Epoch: 4001 Accuracy: 0.7784 Loss: 1.12510467 time: 0.0020s
Epoch: 5001 Accuracy: 0.7784 Loss: 1.12510467 time: 0.0020s
Epoch: 6001 Accuracy: 0.7784 Loss: 1.12510467 time: 0.0020s
Epoch: 7001 Accuracy: 0.7784 Loss: 1.12510467 time: 0.0010s
Epoch: 8001 Accuracy: 0.7784 Loss: 1.12510467 time: 0.0033s
Epoch: 9001 Accuracy: 0.7784 Loss: 1.12510467 time: 0.0030s
Epoch: 10001 Accuracy: 0.7784 Loss: 1.12510467 

Epoch: 1001 Accuracy: 0.7784 Loss: 1.12461066 time: 0.0020s
Epoch: 2001 Accuracy: 0.7784 Loss: 1.12461019 time: 0.0030s
Epoch: 3001 Accuracy: 0.7784 Loss: 1.12460959 time: 0.0030s
Epoch: 4001 Accuracy: 0.7784 Loss: 1.12460911 time: 0.0020s
Epoch: 5001 Accuracy: 0.7784 Loss: 1.12460852 time: 0.0020s
Epoch: 6001 Accuracy: 0.7784 Loss: 1.12460744 time: 0.0020s
Epoch: 7001 Accuracy: 0.7784 Loss: 1.12460601 time: 0.0020s
Epoch: 8001 Accuracy: 0.7784 Loss: 1.12460351 time: 0.0020s
Epoch: 9001 Accuracy: 0.7779 Loss: 1.12459970 time: 0.0020s
Epoch: 10001 Accuracy: 0.7779 Loss: 1.12459123 time: 0.0010s
Optimization Finished!
Total time elapsed: 20.4964s
Epoch: 0001 Accuracy: 0.7779 Loss: 1.12459123 time: 0.0030s
Epoch: 1001 Accuracy: 0.7779 Loss: 1.12459028 time: 0.0020s
Epoch: 2001 Accuracy: 0.7779 Loss: 1.12458897 time: 0.0020s
Epoch: 3001 Accuracy: 0.7779 Loss: 1.12458789 time: 0.0020s
Epoch: 4001 Accuracy: 0.7779 Loss: 1.12458682 time: 0.0020s
Epoch: 5001 Accuracy: 0.7779 Loss: 1.12458599 t

Epoch: 7001 Accuracy: 0.7761 Loss: 1.12418497 time: 0.0020s
Epoch: 8001 Accuracy: 0.7761 Loss: 1.12418497 time: 0.0010s
Epoch: 9001 Accuracy: 0.7761 Loss: 1.12418497 time: 0.0020s
Epoch: 10001 Accuracy: 0.7761 Loss: 1.12418497 time: 0.0020s
Optimization Finished!
Total time elapsed: 19.7440s
Epoch: 0001 Accuracy: 0.7761 Loss: 1.12418497 time: 0.0020s
Epoch: 1001 Accuracy: 0.7761 Loss: 1.12417924 time: 0.0030s
Epoch: 2001 Accuracy: 0.7761 Loss: 1.12417161 time: 0.0020s
Epoch: 3001 Accuracy: 0.7761 Loss: 1.12416184 time: 0.0030s
Epoch: 4001 Accuracy: 0.7765 Loss: 1.12411761 time: 0.0020s
Epoch: 5001 Accuracy: 0.7765 Loss: 1.12411761 time: 0.0020s
Epoch: 6001 Accuracy: 0.7765 Loss: 1.12411761 time: 0.0020s
Epoch: 7001 Accuracy: 0.7765 Loss: 1.12411761 time: 0.0020s
Epoch: 8001 Accuracy: 0.7765 Loss: 1.12411761 time: 0.0020s
Epoch: 9001 Accuracy: 0.7765 Loss: 1.12411761 time: 0.0030s
Epoch: 10001 Accuracy: 0.7765 Loss: 1.12411761 time: 0.0020s
Optimization Finished!
Total time elapsed: 21.

Epoch: 1001 Accuracy: 0.7779 Loss: 1.12376416 time: 0.0020s
Epoch: 2001 Accuracy: 0.7779 Loss: 1.12376392 time: 0.0010s
Epoch: 3001 Accuracy: 0.7779 Loss: 1.12376392 time: 0.0010s
Epoch: 4001 Accuracy: 0.7779 Loss: 1.12376380 time: 0.0020s
Epoch: 5001 Accuracy: 0.7779 Loss: 1.12376368 time: 0.0010s
Epoch: 6001 Accuracy: 0.7779 Loss: 1.12376356 time: 0.0020s
Epoch: 7001 Accuracy: 0.7779 Loss: 1.12376332 time: 0.0020s
Epoch: 8001 Accuracy: 0.7779 Loss: 1.12376273 time: 0.0020s
Epoch: 9001 Accuracy: 0.7779 Loss: 1.12376189 time: 0.0020s
Epoch: 10001 Accuracy: 0.7779 Loss: 1.12376022 time: 0.0030s
Optimization Finished!
Total time elapsed: 18.8257s
Epoch: 0001 Accuracy: 0.7779 Loss: 1.12376022 time: 0.0020s
Epoch: 1001 Accuracy: 0.7779 Loss: 1.12375987 time: 0.0030s
Epoch: 2001 Accuracy: 0.7779 Loss: 1.12375975 time: 0.0020s
Epoch: 3001 Accuracy: 0.7779 Loss: 1.12375975 time: 0.0030s
Epoch: 4001 Accuracy: 0.7779 Loss: 1.12375975 time: 0.0030s
Epoch: 5001 Accuracy: 0.7779 Loss: 1.12375963 t