In [39]:
import torch
import torch.nn as nn
import numpy as np
from torch.nn import init
from random import shuffle, randint
import torch.nn.functional as F
from torch_geometric.datasets import Reddit, PPI, Planetoid
from itertools import combinations, combinations_with_replacement
from sklearn.metrics import f1_score, accuracy_score
from sklearn.decomposition import TruncatedSVD
from sklearn.manifold import SpectralEmbedding
import itertools
import time

In [40]:
# Get paths to datasets.
# Datasets are constructed lazily: constructing a torch_geometric dataset may
# download and process it, so building all four up front (including the large
# Reddit graph) just to use `datasets[DATASET]` is wasteful.
PATH_TO_DATASETS_DIRECTORY = './'


class LazyDatasetDict(dict):
    """dict that constructs (and caches) a dataset the first time its key is indexed.

    Only `d[key]` triggers construction; `len`, iteration and `.get` see just the
    datasets that have already been built.
    """

    def __init__(self, factories):
        super().__init__()
        # name -> zero-argument callable returning the dataset
        self._factories = factories

    def __missing__(self, key):
        # Unknown names still raise KeyError, as the eager dict did.
        built = self._factories[key]()
        self[key] = built
        return built


datasets = LazyDatasetDict({
    'reddit': lambda: Reddit(root=PATH_TO_DATASETS_DIRECTORY + '/datasets/Reddit'),
    'cora': lambda: Planetoid(root=PATH_TO_DATASETS_DIRECTORY + '/datasets/Cora/', name='Cora'),
    'citeseer': lambda: Planetoid(root=PATH_TO_DATASETS_DIRECTORY + '/datasets/CiteSeer/', name='CiteSeer'),
    'pubmed': lambda: Planetoid(root=PATH_TO_DATASETS_DIRECTORY + '/datasets/PubMed/', name='PubMed'),
})

In [41]:
# Obtain dataset
DATASET = 'cora'
PREDICTION = 'node'
dataset = datasets[DATASET]
data = dataset[0]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Number of output classes for each prediction task
predictions = dict(node=dataset.num_classes, link=2, triad=4)

# Training split = every node that is in neither the validation nor the test split
data.train_mask = (~data.val_mask) & (~data.test_mask)

# Dense adjacency matrix with a 1 for every (source, target) column of edge_index.
# NOTE(review): dense n x n float matrix, so memory grows quadratically with the
# node count; confirm before switching DATASET to a large graph such as Reddit.
edges = data.edge_index.t()
adj_mat = torch.zeros((data.num_nodes, data.num_nodes))
adj_mat[edges[:, 0], edges[:, 1]] = 1

# Non-overlapping induced subgraphs: keep only the rows and columns of each split's nodes
adj_train, adj_validation, adj_test = (
    adj_mat[split_mask][:, split_mask]
    for split_mask in (data.train_mask, data.val_mask, data.test_mask)
)

In [42]:
# Edge/non-edge corruption
def corrupt_adj(adj_mat, task, percent=2):
    """ Returns the corrupted version of the adjacency matrix.

    Only ``task == 'link'`` is supported. ``int(percent/100 * num_edges)``
    distinct edges are removed, and up to the same number of distinct
    non-edges are added. Both changes are applied symmetrically.

    Args:
        adj_mat: square adjacency tensor. It is assumed to be symmetric
            (undirected graph), since only its upper triangle is scanned for
            edges. It is not modified.
        task: corruption type; must be ``'link'``.
        percent: percentage of the existing edges to flip.

    Returns:
        ``(adj_mat_corrupted, false_edges, false_non_edges)``:
        ``false_edges`` are the ``[i, j]`` pairs added as spurious edges.
        ``false_non_edges`` are the ``[i, j]`` pairs of real edges that were
        removed.

    Raises:
        ValueError: if ``task`` is not ``'link'``.
    """
    if task != 'link':
        raise ValueError("corrupt_adj only supports task='link', got {!r}".format(task))

    edges = adj_mat.triu().nonzero()
    num_edges = edges.shape[0]
    num_to_corrupt = int(percent/100.0 * num_edges)
    adj_mat_corrupted = adj_mat.clone()
    false_edges, false_non_edges = [], []

    # Edge corruption: sample WITHOUT replacement so exactly num_to_corrupt
    # distinct edges are removed. Sampling with replacement could pick the same
    # edge twice and corrupt fewer edges than requested.
    for idx in np.random.permutation(num_edges)[:num_to_corrupt]:
        i, j = (int(v) for v in edges[int(idx)])
        adj_mat_corrupted[i, j] = 0
        adj_mat_corrupted[j, i] = 0
        false_non_edges.append([i, j])

    # Non-edge corruption: draw random node pairs and keep only those that
    # satisfy all of the following:
    #   - not self-loops;
    #   - not edges of the ORIGINAL graph in either direction;
    #   - not already chosen.
    # Checking adj_mat directly (instead of membership in the upper-triangle
    # edge list) also rejects real edges drawn in (j, i) order.
    num_nodes = adj_mat.shape[0]
    chosen = set()
    max_attempts = 100 * num_to_corrupt  # bounds the loop on very dense graphs
    attempts = 0
    while len(chosen) < num_to_corrupt and attempts < max_attempts and num_nodes > 1:
        attempts += 1
        i, j = (int(v) for v in np.random.randint(num_nodes, size=2))
        pair = (min(i, j), max(i, j))
        if i == j or pair in chosen or adj_mat[i, j] != 0 or adj_mat[j, i] != 0:
            continue
        chosen.add(pair)
        adj_mat_corrupted[i, j] = 1
        adj_mat_corrupted[j, i] = 1
        false_edges.append([i, j])
    return adj_mat_corrupted, false_edges, false_non_edges

In [43]:
# Supervised learning network.
# Width of the t-SVD structural embedding and of every hidden layer.
num_neurons: int = 256
# Per-node input width: node-feature width plus structural-embedding width
input_rep: int = data.num_features + num_neurons

class StructMLP(nn.Module):
    """
        Compute an estimate of the expected value of a function of node embeddings
        Permutation Invariant Function - Deepsets - Zaheer, et al.

        forward() expects a tensor indexed as input_tensor[sample, node, feature].
        Every size-``node_set_size`` combination of nodes is encoded by a
        two-layer Deepsets MLP. The encodings are averaged over the samples, and
        a one-hidden-layer head maps the average to logits.
    """
    def __init__(self, node_set_size=1, input_dim=None, hidden_dim=None, num_outputs=None):
        """
            node_set_size: nodes per set (1 = node, 2 = link, 3 = triad).
            input_dim: per-node input width; defaults to the notebook-level ``input_rep``.
            hidden_dim: hidden width; defaults to the notebook-level ``num_neurons``.
            num_outputs: number of logits; defaults to ``predictions[PREDICTION]``.
        """
        super(StructMLP, self).__init__()

        # Fall back to the notebook configuration only when a size is not given,
        # so the class can also be built without those globals.
        if input_dim is None:
            input_dim = input_rep
        if hidden_dim is None:
            hidden_dim = num_neurons
        if num_outputs is None:
            num_outputs = predictions[PREDICTION]

        self.node_set_size = node_set_size
        self.hidden_dim = hidden_dim

        #Deepsets MLP (layer names are kept so saved state_dicts still load)
        self.ds_layer_1 = nn.Linear(input_dim*node_set_size, hidden_dim)
        self.ds_layer_2 = nn.Linear(hidden_dim, hidden_dim)

        #One Hidden Layer
        self.layer1 = nn.Linear(hidden_dim, hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, num_outputs)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()  # not used by forward(); kept for compatibility

    def forward(self, input_tensor):
        """
            input_tensor: float tensor indexed as [sample, node, feature].
            Returns one row of logits per size-``node_set_size`` node combination,
            in ``itertools.combinations`` order.
        """
        num_samples, num_nodes = input_tensor.shape[0], input_tensor.shape[1]
        # All size-k node combinations, shape (C, k). They are built on the
        # input's device, so the model does not depend on the global `device`.
        comb_tensor = torch.tensor(
            list(combinations(range(num_nodes), self.node_set_size)),
            dtype=torch.long, device=input_tensor.device,
        ).reshape(-1, self.node_set_size)

        # (S, C, k, F) -> (S, C, k*F): concatenate the features of each set's
        # nodes. All samples go through the Deepsets MLP in one batched pass.
        set_init_rep = input_tensor[:, comb_tensor].flatten(start_dim=2)
        x = self.ds_layer_1(set_init_rep)
        x = self.relu(x)
        x = self.ds_layer_2(x)

        # Monte Carlo estimate of the expectation over the embedding samples
        x = x.mean(dim=0)

        #One Hidden Layer for predictor
        x = self.layer1(x)
        x = self.relu(x)
        x = self.layer2(x)
        return x

    def compute_loss(self, input_tensor, target):
        """Cross-entropy between forward(input_tensor) and integer class targets."""
        pred = self.forward(input_tensor)
        return F.cross_entropy(pred, target)

# Nodes per set fed to StructMLP: 1 for node, 2 for link, 3 for anything else (triad)
node_set_size = {'node': 1, 'link': 2}.get(PREDICTION, 3)

# Class targets for each split (only defined for node prediction)
if PREDICTION == 'node':
    target_train, target_val, target_test = (
        data.y[split_mask].type(torch.long)
        for split_mask in (data.train_mask, data.val_mask, data.test_mask)
    )

In [44]:
def sampleZ(adj, ns, ni, n_components=None):
    """Draw `ns` randomized truncated-SVD embeddings of `adj`.

    Each sample fits sklearn's TruncatedSVD with `n_iter=ni` and a random seed
    drawn from [0, 500), so different samples usually differ (seeds can repeat).

    Args:
        adj: adjacency matrix (anything TruncatedSVD.fit_transform accepts).
        ns: number of samples to draw.
        ni: `n_iter` passed to TruncatedSVD.
        n_components: embedding width. It defaults to the notebook-level
            `num_neurons`, because StructMLP's input width assumes the two match.

    Returns:
        list of `ns` float tensors on `device`, one per sample.
    """
    if n_components is None:
        n_components = num_neurons
    seeds = np.random.randint(500, size=ns)
    hidden_samples = []
    for seed in seeds:
        svd = TruncatedSVD(n_components=n_components, n_iter=ni, random_state=int(seed))
        embedding = svd.fit_transform(adj)
        hidden_samples.append(torch.Tensor(embedding).to(device))
    return hidden_samples

# NOTE: num_neurons (StructMLP's embedding width) must equal the number of SVD components per node returned by sampleZ, i.e. hidden_sample.shape[1].

In [45]:
# Experiment sweep. For every (samples, iterations, corruption %) setting:
#   - train RUN_COUNT fresh models on t-SVD embeddings of the corrupted train
#     subgraph;
#   - report micro/weighted F1 on the corrupted test subgraph.
# Parameters for t-SVD sampling
RUN_COUNT = 12
NUM_SAMPLES_LIST = [5]
NUM_ITERS_LIST   = [10]
PERCENT_LIST     = [1, 2, 5]
EPOCHS = 50
LEARNING_RATE = 0.001
mlp_model = 'best_mlp_model.model'  # checkpoint path for the best-validation model


def build_input(adj_corrupted, node_features, num_samples, num_iters):
    """Stack fresh t-SVD samples of `adj_corrupted`, each concatenated with `node_features`.

    Each sample from `sampleZ` is concatenated column-wise (embedding first,
    then features). The results are stacked along a new leading sample axis,
    i.e. indexed as [sample, node, feature]. The result is detached.
    """
    hidden_samples = sampleZ(adj_corrupted, num_samples, num_iters)
    stacked = torch.stack([torch.cat((h.to(device), node_features), 1) for h in hidden_samples])
    return stacked.detach()


result_dict = {}

# These inputs never change between runs or epochs, so move them to the device once.
features_train = data.x[data.train_mask].to(device)
features_val = data.x[data.val_mask].to(device)
features_test = data.x[data.test_mask].to(device)
target_train_device = target_train.to(device)
target_val_device = target_val.to(device)
t_test = target_test.to("cpu").numpy()

for NUM_SAMPLES, NUM_ITERS, PERCENT in itertools.product(NUM_SAMPLES_LIST, NUM_ITERS_LIST, PERCENT_LIST):

    print("\n{} sample(s) of {} iteration(s) of T-SVD at {}% corruption".format(NUM_SAMPLES, NUM_ITERS, PERCENT))
    results = []

    for run in range(RUN_COUNT):
        print(" - Test {}".format(run+1))
        start = time.time()

        # Corrupt edges of each split independently
        adj_train_corrupted, train_false_edges, train_false_non_edges = corrupt_adj(adj_train, 'link', percent=PERCENT)
        adj_val_corrupted, val_false_edges, val_false_non_edges = corrupt_adj(adj_validation, 'link', percent=PERCENT)
        adj_test_corrupted, test_false_edges, test_false_non_edges = corrupt_adj(adj_test, 'link', percent=PERCENT)

        # Create a new model
        mlp = StructMLP(node_set_size).to(device)
        mlp_optimizer = torch.optim.Adam(mlp.parameters(), lr=LEARNING_RATE)

        # Training: a new set of t-SVD samples is drawn every epoch
        validation_loss = float('inf')
        for num_epoch in range(EPOCHS):
            mlp_optimizer.zero_grad()
            input_ = build_input(adj_train_corrupted, features_train, NUM_SAMPLES, NUM_ITERS)
            loss = mlp.compute_loss(input_, target_train_device)
            with torch.no_grad():
                # Validate the current parameters (before this epoch's update)
                input_val = build_input(adj_val_corrupted, features_val, NUM_SAMPLES, NUM_ITERS)
                compute_val_loss = mlp.compute_loss(input_val, target_val_device).item()
                # Always checkpoint at epoch 0 so the file loaded below belongs to THIS
                # run. Otherwise a run that never improves (e.g. NaN losses) would
                # silently reload the previous run's model.
                if num_epoch == 0 or compute_val_loss < validation_loss:
                    validation_loss = compute_val_loss
                    torch.save(mlp.state_dict(), mlp_model)
            loss.backward()
            mlp_optimizer.step()

        # Load best model
        mlp = StructMLP(node_set_size).to(device)
        mlp.load_state_dict(torch.load(mlp_model, map_location=device))
        mlp.eval()

        # Forward pass on test set; argmax of the logits equals argmax of their log-softmax
        input_test = build_input(adj_test_corrupted, features_test, NUM_SAMPLES, NUM_ITERS)
        with torch.no_grad():
            test_pred = mlp(input_test)
        pred = test_pred.argmax(dim=1).to("cpu").numpy()

        # Obtain results for run
        mf1 = f1_score(t_test, pred, average='micro')
        wf1 = f1_score(t_test, pred, average='weighted')
        results.append([mf1, wf1])
        print(" -- MF1, WF1, BVL: {:.3f}, {:.3f}, {:.4f}".format(mf1, wf1, validation_loss))
        print(" -- Finished in:   {:.3f} s".format(time.time() - start))

    results = np.array(results)
    m, s = np.mean(results, axis=0), np.std(results, axis=0)
    print("\n - {}, {}, {}% tests complete".format(NUM_SAMPLES, NUM_ITERS, PERCENT))
    print(" - Micro F1 Score mean[std]:    {:.4f}[{:.4f}]".format(m[0], s[0]))
    print(" - Weighted F1 Score mean[std]: {:.4f}[{:.4f}]".format(m[1], s[1]))
    result_dict[str(NUM_SAMPLES) + "," + str(NUM_ITERS) + "," + str(PERCENT)] = results


5 sample(s) of 10 iteration(s) of T-SVD at 1% corruption
 - Test 1
 -- MF1, WF1, BVL: 0.686, 0.662, 0.9834
 -- Finished in:   120.448 s
 - Test 2
 -- MF1, WF1, BVL: 0.663, 0.628, 1.0623
 -- Finished in:   124.228 s
 - Test 3
 -- MF1, WF1, BVL: 0.670, 0.634, 1.0224
 -- Finished in:   143.160 s
 - Test 4
 -- MF1, WF1, BVL: 0.672, 0.636, 1.0344
 -- Finished in:   117.874 s
 - Test 5
 -- MF1, WF1, BVL: 0.672, 0.632, 1.0826
 -- Finished in:   118.623 s
 - Test 6
 -- MF1, WF1, BVL: 0.676, 0.653, 1.0801
 -- Finished in:   119.749 s
 - Test 7
 -- MF1, WF1, BVL: 0.663, 0.608, 1.0547
 -- Finished in:   119.625 s
 - Test 8
 -- MF1, WF1, BVL: 0.648, 0.594, 1.0862
 -- Finished in:   120.307 s
 - Test 9
 -- MF1, WF1, BVL: 0.663, 0.633, 1.0604
 -- Finished in:   120.228 s
 - Test 10
 -- MF1, WF1, BVL: 0.659, 0.617, 1.0269
 -- Finished in:   120.319 s
 - Test 11
 -- MF1, WF1, BVL: 0.638, 0.585, 1.1132
 -- Finished in:   120.508 s
 - Test 12
 -- MF1, WF1, BVL: 0.676, 0.643, 1.0316
 -- Finished in:   1

In [27]:
# JSON cannot serialise numpy arrays, so convert each setting's results to nested lists
result_data = {setting: runs.tolist() for setting, runs in result_dict.items()}

In [28]:
import json

# NOTE(review): the file name hard-codes "5_10" whatever NUM_SAMPLES_LIST /
# NUM_ITERS_LIST were; confirm it matches the sweep before reusing this cell.
results_path = DATASET + '__svd_5_10_perc.json'
with open(results_path, 'w') as results_file:
    json.dump(result_data, results_file)