In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn import Linear, Parameter
from torch_geometric.datasets import WebKB, Planetoid, WikipediaNetwork
from torch_geometric.nn import GCNConv, VGAE
#from torch_geometric.utils import train_test_split_edges
from torch_geometric.transforms import RandomLinkSplit
from torch_geometric.utils import negative_sampling
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree



In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
dataset = WebKB(root="/home/siddy/META/data", name="Cornell")
data = dataset[0].to(device)

In [4]:
data

Data(x=[183, 1703], edge_index=[2, 298], y=[183], train_mask=[183, 10], val_mask=[183, 10], test_mask=[183, 10])

In [5]:
edge_index= data.edge_index

In [6]:
transform = RandomLinkSplit(is_undirected=True, add_negative_train_samples=False)

In [7]:
train_data, val_data, test_data = transform(data)

In [8]:
train_data

Data(x=[183, 1703], edge_index=[2, 236], y=[183], train_mask=[183, 10], val_mask=[183, 10], test_mask=[183, 10], edge_label=[118], edge_label_index=[2, 118])

In [9]:
val_data

Data(x=[183, 1703], edge_index=[2, 236], y=[183], train_mask=[183, 10], val_mask=[183, 10], test_mask=[183, 10], edge_label=[32], edge_label_index=[2, 32])

In [10]:
test_data

Data(x=[183, 1703], edge_index=[2, 268], y=[183], train_mask=[183, 10], val_mask=[183, 10], test_mask=[183, 10], edge_label=[66], edge_label_index=[2, 66])

In [11]:
num_nodes_train = train_data.x.size(0)

In [12]:
def neg_index(num_nodes, edge_index):
    adj_matrix = torch.zeros((num_nodes, num_nodes), dtype=torch.long)
    for i in range(edge_index.size(1)):
        u, v = int(edge_index[0, i]), int(edge_index[1, i])
        adj_matrix[u, v] = 1
        adj_matrix[v, u] = 1

    neg_indices = []
    for i in range(edge_index.size(1)):
        u, v = int(edge_index[0, i]), int(edge_index[1, i])
        adj_squared = torch.mm(adj_matrix, adj_matrix)
        adj_row_u = adj_matrix[u]
        adj_row_v = adj_matrix[v]
        bitwise_and_result = torch.bitwise_and(adj_squared[u], adj_squared[v])
        sum_bitwise_and_result = torch.sum(bitwise_and_result)
        sum_adj_u = torch.sum(adj_row_u)
        sum_adj_v = torch.sum(adj_row_v)
        neg_index = sum_bitwise_and_result.float() / (sum_adj_u + sum_adj_v)
        neg_indices.append(neg_index.item())

    return neg_indices

In [13]:
def pos_index(num_nodes, edge_index):
    adj_matrix = torch.zeros((num_nodes, num_nodes), dtype=torch.long)
    for i in range(edge_index.size(1)):
        u, v = int(edge_index[0, i]), int(edge_index[1, i])
        adj_matrix[u, v] = 1
        adj_matrix[v, u] = 1

    pos_indices = []
    for i in range(edge_index.size(1)):
        u, v = int(edge_index[0, i]), int(edge_index[1, i])
        adj_row_u = adj_matrix[u]
        adj_row_v = adj_matrix[v]
        bitwise_and_result = torch.bitwise_and(adj_row_u, adj_row_v)
        sum_bitwise_and_result = torch.sum(bitwise_and_result)
        sum_adj_u = torch.sum(adj_row_u)
        sum_adj_v = torch.sum(adj_row_v)
        pos_index = sum_bitwise_and_result.float() / (sum_adj_u + sum_adj_v)
        pos_indices.append(pos_index.item())

    return pos_indices

In [14]:
class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')
        self.lin = Linear(in_channels, out_channels, bias=False)
        self.bias = Parameter(torch.empty(out_channels))

        self.reset_parameters()

    def reset_parameters(self):
        self.lin.reset_parameters()
        self.bias.data.zero_()

    def forward(self, x, edge_index):
        # what is the shape of inpur x ? - needed [N, in_channels]
        # edge indices shape needed is [2, E]

        #add self_loops to the adjacency matrix, how to give num nodes?
        #edge_index, _ = add_self_loops(edge_index)
        #print(edge_index)
        # linearly transform node feature matrix
        x = self.lin(x)
        #x = torch.index_select(input=x, index=edge_index[0], dim=0)
        # x_ball = torch.cat([torch.index_select(input=x, index=edge_index[0], dim=0), NOTE THAT IT WILL GIVE INDEX OUT OF RANGE ONE OPTION IS TO GO WITH REINDEXING
        #             torch.index_select(input=x, index=edge_index[1], dim=0)],dim=0)
        #compute normalization
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(0.5)
        deg_inv_sqrt[deg_inv_sqrt==float('inf')] = 0
        #print(deg_inv_sqrt.shape)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # propagating messages
        out = self.propagate(edge_index, x=x, norm=norm)
        #out = torch.index_select(input=out, index=min(edge_index[0]), dim=0) #NOTE TRICK IS TO PICK MIN EDGE INDEX AS IT WILL CORRESPOND TO THE CENTER NODE OF THE BALL
        # bias
        out += self.bias
        return torch.squeeze(out)

    def message(self, x_j, norm):
        # x_j has shape [E, out_channels]
        # normalize node features
        return norm.view(-1,1) *x_j

In [15]:
class GCNNet(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv = GCNConv(in_channels, hidden_channels)
        self.fc = Linear(hidden_channels, out_channels)
    def forward(self, x, edge_index):
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.fc(self.conv(x, edge_index))
        return x

In [42]:
alpha, beta = 0.005, 0.005
def custom_objective(sigma_1, sigma_2):
    return (alpha*sigma_1 - 0.5)**2 + beta*sigma_2

In [17]:
# Encoder
class GCNEncoder(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(GCNEncoder, self).__init__()
        self.conv1 = GCNConv(in_channels, 2 * out_channels)
        self.conv_mu = GCNConv(2 * out_channels, out_channels)
        self.conv_logstd = GCNConv(2 * out_channels, out_channels)

    def forward(self,x ,edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv_mu(x, edge_index), self.conv_logstd(x, edge_index)

In [18]:
out_channels= dataset.num_classes
num_features = dataset.num_features
epochs=100
neg_indices = neg_index(num_nodes_train, train_data.edge_index)
pos_indices = pos_index(num_nodes_train, train_data.edge_index)

In [19]:
len(neg_indices)

236

In [22]:
obj = []
for i in range(train_data.edge_index.size(1)):
        #u, v = int(data.edge_index[0, i]), int(data.edge_index[1, i])
        #current_pair
    sigma_1= torch.tensor(neg_indices[i], requires_grad=True)
    sigma_2 = torch.tensor(pos_indices[i], requires_grad=True)
    objective = custom_objective(sigma_1, sigma_2)
    obj.append(objective)

In [20]:
range(train_data.edge_index.size(1))

range(0, 236)

In [28]:
len(pos_indices)

236

In [21]:
model = VGAE(GCNEncoder(num_features, out_channels))
gcn_net = GCNNet(in_channels=data.x.size(1), hidden_channels=64, out_channels=dataset.num_classes)
gcn_net = gcn_net.to(device)
model = model.to(device)
x = data.x.to(device)
optimizer = torch.optim.Adam(list(model.parameters())+ list(gcn_net.parameters()), lr=0.001, weight_decay=5e-4)

In [49]:
def train(train_data, neg_indices, pos_indices):
    model.train()
    optimizer.zero_grad()
    neg_edge_index = negative_sampling(
        edge_index= train_data.edge_index,
        num_nodes= train_data.x.size(0),
        num_neg_samples=train_data.edge_index.size(1)
    )
    # neg_indices = neg_index(num_nodes_train, train_data.edge_index)
    # pos_indices = pos_index(num_nodes_train, train_data.edge_index)
    z = model.encode(x, train_data.edge_index)
    #print(f"latent space shape: {z.shape}")
    #adj = torch.sigmoid(torch.matmul(z, z.t()))
    #print(f"adj matrix shape: {adj.shape}")
    #loss = model.recon_loss(z, train_data.edge_index)
    loss = (1 / data.num_nodes) * model.kl_loss()  # new line
    #adj_binary = (adj > 0.5).float()
    #edge_list = adj.nonzero(as_tuple=False)
    #edge_list = torch.permute(torch.tensor(edge_list, dtype=torch.long), (1,0)).to(device)
    # neg_indices = neg_index(num_nodes_train, train_data.edge_list)
    # pos_indices = pos_index(num_nodes_train, train_data.edge_list)
    for i in range(len(neg_indices)):
        sigma_1 = torch.tensor(neg_indices[i], requires_grad=True)
        sigma_2 = torch.tensor(pos_indices[i], requires_grad=True)
        objective = custom_objective(sigma_1, sigma_2)
        loss += objective
    out = gcn_net(train_data.x, train_data.edge_index)
    nc_loss = F.cross_entropy(out[train_data.train_mask[:,0]], train_data.y[train_data.train_mask[:,0]])
    loss += nc_loss
    loss.backward()
    optimizer.step()
    return float(loss)


def test(test_data):
    model.eval()
    gcn_net.eval()
    with torch.no_grad():
        test_neg_edge_index = negative_sampling(
            edge_index= test_data.edge_index,
            num_nodes= test_data.x.size(0),
            num_neg_samples=test_data.edge_index.size(1)
        )
        z = model.encode(x, test_data.edge_index)
        out = 
        accs=[]
        accs.append(int())
    
    return model.test(z, test_data.edge_index, test_neg_edge_index)

In [50]:
num_nodes_train = train_data.x.size(0)

In [51]:
for epoch in range(1, epochs + 1):
    loss = train(train_data, neg_indices, pos_indices)
    print(loss)
    auc, ap = test(test_data)
    print('Epoch: {:03d}, AUC: {:.4f}, AP: {:.4f}'.format(epoch, auc, ap))

59.67121505737305
Epoch: 001, AUC: 0.5130, AP: 0.5084
59.56915283203125
Epoch: 002, AUC: 0.5111, AP: 0.5075
59.63116455078125
Epoch: 003, AUC: 0.5148, AP: 0.5093
59.61537170410156
Epoch: 004, AUC: 0.5056, AP: 0.5047
59.570648193359375
Epoch: 005, AUC: 0.5185, AP: 0.5113
59.53494644165039
Epoch: 006, AUC: 0.5148, AP: 0.5093
59.57355499267578
Epoch: 007, AUC: 0.5074, AP: 0.5056
59.575401306152344
Epoch: 008, AUC: 0.5074, AP: 0.5056
59.55477523803711
Epoch: 009, AUC: 0.5037, AP: 0.5037
59.565792083740234
Epoch: 010, AUC: 0.5093, AP: 0.5065
59.5667839050293
Epoch: 011, AUC: 0.5093, AP: 0.5065
59.614925384521484
Epoch: 012, AUC: 0.5185, AP: 0.5113
59.58082962036133
Epoch: 013, AUC: 0.5074, AP: 0.5056
59.59929656982422
Epoch: 014, AUC: 0.5093, AP: 0.5065
59.60763168334961
Epoch: 015, AUC: 0.5037, AP: 0.5037
59.61659622192383
Epoch: 016, AUC: 0.5074, AP: 0.5056
59.58620834350586
Epoch: 017, AUC: 0.5093, AP: 0.5065
59.56905746459961
Epoch: 018, AUC: 0.5074, AP: 0.5056
59.58341979980469
Epoch: 