In [1]:
import os

from torch.optim import Adam
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.nn.models import InnerProductDecoder, VGAE
from torch_geometric.nn.conv import GCNConv
from torch_geometric.utils import negative_sampling, remove_self_loops, add_self_loops
import torch_geometric.transforms as T
from torch_geometric.utils import train_test_split_edges

from data_generator import ssp_data
from config.config import parse_args

import pickle

In [32]:
class GCNEncoder(nn.Module):
    """GCN encoder with one shared layer and two parallel output heads.

    The shared ``GCNConv`` layer maps ``in_channels`` to ``hidden_channels``
    and is followed by a ReLU. Its activation is fed to two separate
    ``GCNConv`` heads, each mapping ``hidden_channels`` to ``out_channels``.

    NOTE(review): the second head is named ``logvar`` here, but the
    ``dir(model)`` output of the wrapping VGAE lists a ``__logstd__``
    attribute, so the consumer may interpret this output as a log standard
    deviation rather than a log-variance -- confirm against the VGAE docs.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        # Shared trunk.
        self.gcn_shared = GCNConv(in_channels, hidden_channels)
        # Two heads with identical input/output sizes.
        self.gcn_mu = GCNConv(hidden_channels, out_channels)
        self.gcn_logvar = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Return the ``(mu, logvar)`` pair computed from ``x`` and ``edge_index``.

        ReLU is applied only to the shared hidden activation; both head
        outputs are returned without a non-linearity.
        """
        hidden = self.gcn_shared(x, edge_index)
        hidden = F.relu(hidden)
        return self.gcn_mu(hidden, edge_index), self.gcn_logvar(hidden, edge_index)

class DeepVGAE(VGAE):
    """VGAE built from a ``GCNEncoder`` and an ``InnerProductDecoder``.

    Args:
        in_channels: Input size passed to the encoder. Defaults to 163800,
            the value that was previously hard-coded.
        hidden_channels: Hidden size of the encoder's shared layer.
            Defaults to 4.
        out_channels: Output size of each encoder head. Defaults to 2.
    """

    def __init__(self, in_channels=163800, hidden_channels=4, out_channels=2):
        super().__init__(
            encoder=GCNEncoder(in_channels, hidden_channels, out_channels),
            decoder=InnerProductDecoder(),
        )

    def forward(self, x, edge_index):
        """Encode ``x`` over ``edge_index`` and decode all node pairs.

        Returns:
            The result of ``self.decoder.forward_all`` on the encoded
            embedding.
        """
        z = self.encode(x, edge_index)
        return self.decoder.forward_all(z)

    def loss(self, x, pos_edge_index, all_edge_index):
        """Return the positive-edge reconstruction loss plus a scaled KL term.

        Args:
            x: Node feature matrix.
            pos_edge_index: Edges used both for message passing in the
                encoder and as positive reconstruction targets.
            all_edge_index: Accepted for interface compatibility; currently
                unused because no negative-sample term is computed.

        Returns:
            Scalar tensor ``pos_loss + kl_loss``.
        """
        z = self.encode(x, pos_edge_index)

        # 1e-15 keeps log() finite when the decoder outputs exactly 0.
        edge_prob = self.decoder(z, pos_edge_index, sigmoid=True)
        pos_loss = -torch.log(edge_prob + 1e-15).mean()

        # NOTE(review): the negative-sampling term is intentionally left out,
        # as in the original code. Without it the reconstruction loss only
        # rewards predicting edges. To re-enable it, sample negatives with
        # negative_sampling() from all_edge_index (self-loops normalised via
        # remove_self_loops/add_self_loops) and add
        # -log(1 - decoder(z, neg_edge_index) + 1e-15).mean(), guarding
        # against an empty negative set.

        # KL term scaled by 1 / x.size(0) (rows of x; presumably the number
        # of nodes -- confirm).
        kl_loss = 1 / x.size(0) * self.kl_loss()

        return pos_loss + kl_loss

    def single_test(self, x, train_pos_edge_index, test_pos_edge_index, test_neg_edge_index):
        """Encode with the training edges and score held-out edges.

        Returns:
            The two values returned by ``self.test`` for the given positive
            and negative test edges (named ROC-AUC and average precision by
            the original code).
        """
        with torch.no_grad():
            z = self.encode(x, train_pos_edge_index)
        roc_auc, avg_precision = self.test(z, test_pos_edge_index, test_neg_edge_index)
        return roc_auc, avg_precision

In [33]:
# Load the pickled dataset object. The context manager closes the file handle
# even if unpickling fails.
# NOTE(security): pickle.load can execute arbitrary code -- only load
# 'data_pickle' if it comes from a trusted source.
with open('data_pickle', 'rb') as file:
    ssp_obj = pickle.load(file)

In [34]:
# Display the Data object that the setup cell below assigns to `data`.
ssp_obj.test_data

Data(x=[39, 163800], edge_index=[2, 1482], y=[163800])

In [35]:
# Fix the torch RNG so weight initialisation and sampling are repeatable.
torch.manual_seed(12345)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = DeepVGAE().to(device)
optimizer = Adam(model.parameters(), lr=0.01)

# Split a copy: train_test_split_edges rewrites the Data object it receives,
# so splitting ssp_obj.test_data directly would make this cell fail when it
# is run a second time.
data = ssp_obj.test_data.clone()

# Capture the full edge set before the split replaces data.edge_index, and
# keep it on the same device as the model.
all_edge_index = data.edge_index.to(device)

# 5% validation edges, 10% test edges; move the result to the model's device
# so training does not hit a CPU/GPU tensor mismatch when CUDA is available.
data = train_test_split_edges(data, 0.05, 0.1).to(device)



In [36]:
# Training schedule: total optimisation steps and how often to report.
NUM_EPOCHS = 4000
LOG_INTERVAL = 2

for epoch in range(NUM_EPOCHS):
    # One full-batch optimisation step.
    model.train()
    optimizer.zero_grad()
    loss = model.loss(data.x, data.train_pos_edge_index, all_edge_index)
    loss.backward()
    optimizer.step()

    # Only report on every LOG_INTERVAL-th epoch.
    if epoch % LOG_INTERVAL != 0:
        continue

    model.eval()
    # Evaluation through model.single_test(data.x, data.train_pos_edge_index,
    # data.test_pos_edge_index, data.test_neg_edge_index) is currently
    # disabled; only the training loss is reported.
    print("Epoch {} - Loss: {}".format(epoch, loss.cpu().item()))

Epoch 0 - Loss: 6220097.0
Epoch 2 - Loss: 357617856.0
Epoch 4 - Loss: 0.8234779834747314
Epoch 6 - Loss: 0.7873738408088684
Epoch 8 - Loss: 0.8217275142669678
Epoch 10 - Loss: 0.9174346923828125
Epoch 12 - Loss: 0.9032240509986877
Epoch 14 - Loss: 0.8045551776885986
Epoch 16 - Loss: 0.8962275385856628
Epoch 18 - Loss: 0.8130490183830261
Epoch 20 - Loss: 0.7804829478263855
Epoch 22 - Loss: 0.8159571290016174
Epoch 24 - Loss: 0.8280089497566223
Epoch 26 - Loss: 0.8856480717658997
Epoch 28 - Loss: 0.7686632871627808
Epoch 30 - Loss: 0.810041069984436
Epoch 32 - Loss: 0.8103824853897095
Epoch 34 - Loss: 0.8117279410362244
Epoch 36 - Loss: 0.7547846436500549
Epoch 38 - Loss: 0.7436257600784302
Epoch 40 - Loss: 0.7151547074317932
Epoch 42 - Loss: 0.7463626861572266
Epoch 44 - Loss: 0.7859660387039185
Epoch 46 - Loss: 0.7308390736579895
Epoch 48 - Loss: 0.7942070364952087
Epoch 50 - Loss: 0.7399740219116211
Epoch 52 - Loss: 0.771436333656311
Epoch 54 - Loss: 0.7557837963104248
Epoch 56 - Loss

In [38]:
# Debugging aid: list every attribute name available on the model object.
dir(model)

['T_destination',
 '__annotations__',
 '__call__',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattr__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__logstd__',
 '__lt__',
 '__module__',
 '__mu__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__setstate__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_apply',
 '_backward_hooks',
 '_buffers',
 '_call_impl',
 '_forward_hooks',
 '_forward_pre_hooks',
 '_get_backward_hooks',
 '_get_name',
 '_is_full_backward_hook',
 '_load_from_state_dict',
 '_load_state_dict_post_hooks',
 '_load_state_dict_pre_hooks',
 '_maybe_warn_non_full_backward_hook',
 '_modules',
 '_named_members',
 '_non_persistent_buffers_set',
 '_parameters',
 '_register_load_state_dict_pre_hook',
 '_register_state_dict_hook',
 '_replicate_for_data_parallel',
 '_save_to_state_dict',
 '_slow_forward',
 '