In [4]:
import os

from torch.optim import Adam
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.nn.models import InnerProductDecoder, VGAE
from torch_geometric.nn.conv import GCNConv
from torch_geometric.utils import negative_sampling, remove_self_loops, add_self_loops
import torch_geometric.transforms as T
from torch_geometric.utils import train_test_split_edges

from data_generator import ssp_data
from config.config import parse_args

import pickle

In [11]:
class GCNEncoder(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GCNEncoder, self).__init__()
        self.gcn_shared = GCNConv(in_channels, hidden_channels)
        self.gcn_mu = GCNConv(hidden_channels, out_channels)
        self.gcn_logvar = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = F.relu(self.gcn_shared(x, edge_index))
        mu = self.gcn_mu(x, edge_index)
        logvar = self.gcn_logvar(x, edge_index)
        return mu, logvar

class DeepVGAE(VGAE):
    def __init__(self):
        super(DeepVGAE, self).__init__(encoder=GCNEncoder(74100,
                                                          32,
                                                          41),
                                       decoder=InnerProductDecoder())

    def forward(self, x, edge_index):
        z = self.encode(x, edge_index)
        adj_pred = self.decoder.forward_all(z)
        return adj_pred

    def loss(self, x, pos_edge_index, all_edge_index):
        z = self.encode(x, pos_edge_index)

        pos_loss = -torch.log(
            self.decoder(z, pos_edge_index, sigmoid=True) + 1e-15).mean()

        # Do not include self-loops in negative samples
        all_edge_index_tmp, _ = remove_self_loops(all_edge_index)
        all_edge_index_tmp, _ = add_self_loops(all_edge_index_tmp)

        neg_edge_index = negative_sampling(all_edge_index_tmp, z.size(0), pos_edge_index.size(1))
        neg_loss = -torch.log(1 - self.decoder(z, neg_edge_index, sigmoid=True) + 1e-15).mean()

        kl_loss = 1 / x.size(0) * self.kl_loss()

        return pos_loss + neg_loss + kl_loss

    def single_test(self, x, train_pos_edge_index, test_pos_edge_index, test_neg_edge_index):
        with torch.no_grad():
            z = self.encode(x, train_pos_edge_index)
        roc_auc_score, average_precision_score = self.test(z, test_pos_edge_index, test_neg_edge_index)
        return roc_auc_score, average_precision_score

In [12]:
file = open('data_pickle', 'rb')
ssp_obj = pickle.load(file)

In [13]:
ssp_obj.train_data

Data(x=[41, 74100], edge_index=[2, 1482], y=[74100])

In [15]:
torch.manual_seed(12345)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = DeepVGAE().to(device)
optimizer = Adam(model.parameters(), lr=0.01)

data = ssp_obj.train_data
all_edge_index = data.edge_index
data = train_test_split_edges(data, 0.05, 0.1)



In [16]:
for epoch in range(4000):
    model.train()
    optimizer.zero_grad()
    loss = model.loss(data.x, data.train_pos_edge_index, all_edge_index)
    loss.backward()
    optimizer.step()
    if epoch % 2 == 0:
        model.eval()
        roc_auc, ap = model.single_test(data.x,
                                        data.train_pos_edge_index,
                                        data.test_pos_edge_index,
                                        data.test_neg_edge_index)
        print("Epoch {} - Loss: {} ROC_AUC: {} Precision: {}".format(epoch, loss.cpu().item(), roc_auc, ap))

Epoch 0 - Loss: 106515632.0 ROC_AUC: 0.5 Precision: 0.6379310344827587
Epoch 2 - Loss: 5302139392.0 ROC_AUC: 0.5 Precision: 0.6379310344827587
Epoch 4 - Loss: 5057438720.0 ROC_AUC: 0.5 Precision: 0.6379310344827587
Epoch 6 - Loss: 3018239232.0 ROC_AUC: 0.5 Precision: 0.6379310344827587
Epoch 8 - Loss: 1682616576.0 ROC_AUC: 0.5 Precision: 0.6379310344827587
Epoch 10 - Loss: 1009156544.0 ROC_AUC: 0.5 Precision: 0.6379310344827587
Epoch 12 - Loss: 794441152.0 ROC_AUC: 0.5 Precision: 0.6379310344827587
Epoch 14 - Loss: 681479808.0 ROC_AUC: 0.5 Precision: 0.6379310344827587
Epoch 16 - Loss: 574769728.0 ROC_AUC: 0.5 Precision: 0.6379310344827587
Epoch 18 - Loss: 495180000.0 ROC_AUC: 0.5 Precision: 0.6379310344827587
Epoch 20 - Loss: 446272512.0 ROC_AUC: 0.5 Precision: 0.6379310344827587
Epoch 22 - Loss: 401086944.0 ROC_AUC: 0.5 Precision: 0.6379310344827587
Epoch 24 - Loss: 360719200.0 ROC_AUC: 0.5 Precision: 0.6379310344827587
Epoch 26 - Loss: 329076992.0 ROC_AUC: 0.5 Precision: 0.637931034