In [72]:
import torch
import torch.nn as nn
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid, Reddit
from torch import Tensor

In [73]:
'''
class Discriminator(nn.Module):
    def __init__(self, h, summary, dim):
        super(Discriminator, self).__init__()
        self.h = h
        self.summary = summary
        self.weight = nn.Parameter(torch.Tensor(dim, dim))
        nn.init.xavier_uniform_(self.weight)
        
    def forward(self):
        value = torch.matmul(h, torch.matmul(self.weight, summary))
        return torch.sigmoid(value)
'''

'\nclass Discriminator(nn.Module):\n    def __init__(self, h, summary, dim):\n        super(Discriminator, self).__init__()\n        self.h = h\n        self.summary = summary\n        self.weight = nn.Parameter(torch.Tensor(dim, dim))\n        nn.init.xavier_uniform_(self.weight)\n        \n    def forward(self):\n        value = torch.matmul(h, torch.matmul(self.weight, summary))\n        return torch.sigmoid(value)\n'

In [74]:
'''
class GCN(nn.Module):
    def __init__(self, ft_in, n_fts):
        super(GCN, self).__init__()
        self.conv = GCNConv(ft_in, n_fts)
        self.act = nn.PReLU(n_fts)
        
    def forward(self, x, edge_index):
        x = self.conv(x, edge_index)
        x = self.act(x)
        return x
'''

'\nclass GCN(nn.Module):\n    def __init__(self, ft_in, n_fts):\n        super(GCN, self).__init__()\n        self.conv = GCNConv(ft_in, n_fts)\n        self.act = nn.PReLU(n_fts)\n        \n    def forward(self, x, edge_index):\n        x = self.conv(x, edge_index)\n        x = self.act(x)\n        return x\n'

In [85]:
class GCNlayer(nn.Module):
    def __init__(self, in_ftr, out_ftr):
        super(GCNlayer, self).__init__()
        self.conv = GCNConv(in_ftr, out_ftr)
        self.activation = nn.PReLU(out_ftr)
        
    def forward(self, x, edge_index):
        x = self.conv(x, edge_index)
        x = self.activation(x)
        return x

In [86]:
class DGI(nn.Module):
    def __init__(self, data, dim):
        super(DGI, self).__init__()
        self.dim = dim
        self.data = Data(data.x, data.edge_index)
        self.loss = nn.BCEWithLogitsLoss()
        self.weight = nn.Parameter(torch.Tensor(self.dim, self.dim))
        nn.init.xavier_uniform_(self.weight)
        
    def discriminator(self, h, summary):
        value = torch.matmul(h, torch.matmul(self.weight, summary))
        return torch.sigmoid(value)
    
    def corruption(self, data):
        return Data(self.data.x[torch.randperm(self.data.x.size(0))], self.data.edge_index)
    
    def forward(self):
        pos_x = self.data
        neg_x = self.corruption(pos_x)
        encoder = GCNlayer(self.data.num_features, self.dim)
        pos_h = encoder(pos_x.x, pos_x.edge_index)
        neg_h = encoder(neg_x.x, neg_x.edge_index)
        summary = torch.sigmoid(torch.mean(pos_h, dim = 0))
        
        pos_h = self.discriminator(pos_h,summary)
        neg_h = self.discriminator(neg_h,summary)
        loss_pos = self.loss(pos_h, torch.ones_like(pos_h))
        loss_neg = self.loss(neg_h, torch.zeros_like(neg_h))
        return loss_pos + loss_neg
    
    def predict(self, data):
        pos_x = data
        neg_x = self.corruption(pos_x)
        encoder = GCNlayer(self.data.num_features, self.dim)
        pos_h = encoder(pos_x.x, pos_x.edge_index)
        neg_h = encoder(neg_x.x, neg_x.edge_index)
        summary = torch.sigmoid(torch.mean(pos_h, dim = 0))
        return pos_h, neg_h, summary
    
    def test(
        self,
        train_z: Tensor,
        train_y: Tensor,
        test_z: Tensor,
        test_y: Tensor,
        solver: str = 'lbfgs',
        multi_class: str = 'auto',
        *args,
        **kwargs,
    ) -> float:
        r"""Evaluates latent space quality via a logistic regression downstream
        task."""
        from sklearn.linear_model import LogisticRegression

        clf = LogisticRegression(solver=solver, multi_class=multi_class, *args,
                                 **kwargs).fit(train_z.detach().cpu().numpy(),
                                               train_y.detach().cpu().numpy())
        return clf.score(test_z.detach().cpu().numpy(),
                         test_y.detach().cpu().numpy())


In [87]:
dataset = Planetoid(root='/tmp/Cora', name = 'Cora')
data = dataset[0]

In [88]:
data

Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])

In [89]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = DGI(data, 512)
optimizer = torch.optim.Adam(model.parameters(), lr = 0.001)

In [90]:
def train():
    model.train()
    optimizer.zero_grad()
    loss = model()
    loss.backward()
    optimizer.step()
    return loss.item()

In [91]:
def test():
    model.eval()
    z, _, _ = model.predict(data)
    acc = model.test(z[data.train_mask],data.y[data.train_mask],z[data.test_mask],data.y[data.test_mask],max_iter=10)
    return acc

In [92]:
for epoch in range(30):
    loss = train()
    print("Epoch: {:d}, Loss: {:.4f}".format(epoch+1, loss))
acc = test()
print("Accuracy: {:.4f}".format(acc))

Epoch: 1, Loss: 1.4673
Epoch: 2, Loss: 1.4064
Epoch: 3, Loss: 1.3930
Epoch: 4, Loss: 1.3883
Epoch: 5, Loss: 1.3881
Epoch: 6, Loss: 1.3868
Epoch: 7, Loss: 1.3865
Epoch: 8, Loss: 1.3865
Epoch: 9, Loss: 1.3864
Epoch: 10, Loss: 1.3864
Epoch: 11, Loss: 1.3864
Epoch: 12, Loss: 1.3864
Epoch: 13, Loss: 1.3864
Epoch: 14, Loss: 1.3863
Epoch: 15, Loss: 1.3863
Epoch: 16, Loss: 1.3863
Epoch: 17, Loss: 1.3863
Epoch: 18, Loss: 1.3863
Epoch: 19, Loss: 1.3863
Epoch: 20, Loss: 1.3863
Epoch: 21, Loss: 1.3863
Epoch: 22, Loss: 1.3863
Epoch: 23, Loss: 1.3863
Epoch: 24, Loss: 1.3863
Epoch: 25, Loss: 1.3863
Epoch: 26, Loss: 1.3863
Epoch: 27, Loss: 1.3863
Epoch: 28, Loss: 1.3863
Epoch: 29, Loss: 1.3863
Epoch: 30, Loss: 1.3863
Accuracy: 0.7380


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
