In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn import Linear, Parameter
from torch_geometric.datasets import WebKB, Planetoid, WikipediaNetwork
from torch_geometric.nn import GCNConv, VGAE
#from torch_geometric.utils import train_test_split_edges
from torch_geometric.transforms import RandomLinkSplit
from torch_geometric.utils import negative_sampling
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# import torch_geometric.transforms as T
# transform = T.Compose([
#     T.RandomNodeSplit(num_val=500, num_test=500),
#     T.TargetIndegree(),
# ])

# dataset = Planetoid(root="data", name="CiteSeer", transform=transform)
dataset = WebKB(root="/home/siddy/META/data", name="cornell")
data = dataset[0].to(device)

In [35]:
import wandb
run_vgae_rewire = wandb.init(
        project = "VAE-Experiments-regularizer",
        config = {
            "architecture": "VGAE+Custom_loss+Distributed_backprop",
            "model":"gcn",
            "dataset":"cornell",
            "epoch": 500,
            "lr": 0.01,
            "weight_decay":1e-3,
            "Batch size": 1,
            "mask":1,
            "alpha":1,
            "beta":1,
        }
    )

In [36]:
data

Data(x=[183, 1703], edge_index=[2, 298], y=[183], train_mask=[183, 10], val_mask=[183, 10], test_mask=[183, 10])

In [37]:
edge_index= data.edge_index
num_nodes = data.x.size(0)
adj_matrix = torch.zeros((num_nodes, num_nodes), dtype=torch.long)
for i in range(edge_index.size(1)):
    u, v = int(edge_index[0, i]), int(edge_index[1, i])
    adj_matrix[u, v] = 1
    adj_matrix[v, u] = 1

In [38]:
adj_squared = torch.mm(adj_matrix, adj_matrix)

In [39]:
def neg_index(edge_index, adj_matrix, adj_squared):
    neg_indices = []
    for i in range(edge_index.size(1)):
        u, v = int(edge_index[0, i]), int(edge_index[1, i])
        #adj_squared = torch.mm(adj_matrix, adj_matrix)
        adj_row_u = adj_matrix[u]
        adj_row_v = adj_matrix[v]
        bitwise_and_result = torch.bitwise_and(adj_squared[u], adj_row_v[v])
        sum_bitwise_and_result = torch.sum(bitwise_and_result)
        sum_adj_u = torch.sum(adj_row_u)
        sum_adj_v = torch.sum(adj_row_v)
        neg_index = sum_bitwise_and_result.float() / (sum_adj_u + sum_adj_v)
        neg_indices.append(neg_index.item())

    return neg_indices

In [40]:
def pos_index(edge_index, adj_matrix, adj_squared):
    pos_indices = []
    for i in range(edge_index.size(1)):
        u, v = int(edge_index[0, i]), int(edge_index[1, i])
        adj_row_u = adj_matrix[u]
        adj_row_v = adj_matrix[v]
        bitwise_and_result = torch.bitwise_and(adj_row_u, adj_row_v)
        sum_bitwise_and_result = torch.sum(bitwise_and_result)
        sum_adj_u = torch.sum(adj_row_u)
        sum_adj_v = torch.sum(adj_row_v)
        pos_index = sum_bitwise_and_result.float() / (sum_adj_u + sum_adj_v)
        pos_indices.append(pos_index.item())

    return pos_indices

In [41]:
class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')
        self.lin = Linear(in_channels, out_channels, bias=False)
        self.bias = Parameter(torch.empty(out_channels))

        self.reset_parameters()

    def reset_parameters(self):
        self.lin.reset_parameters()
        self.bias.data.zero_()

    def forward(self, x, edge_index):
        # what is the shape of inpur x ? - needed [N, in_channels]
        # edge indices shape needed is [2, E]

        #add self_loops to the adjacency matrix, how to give num nodes?
        #edge_index, _ = add_self_loops(edge_index)
        #print(edge_index)
        # linearly transform node feature matrix
        x = self.lin(x)
        #x = torch.index_select(input=x, index=edge_index[0], dim=0)
        # x_ball = torch.cat([torch.index_select(input=x, index=edge_index[0], dim=0), NOTE THAT IT WILL GIVE INDEX OUT OF RANGE ONE OPTION IS TO GO WITH REINDEXING
        #             torch.index_select(input=x, index=edge_index[1], dim=0)],dim=0)
        #compute normalization
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(0.5)
        deg_inv_sqrt[deg_inv_sqrt==float('inf')] = 0
        #print(deg_inv_sqrt.shape)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # propagating messages
        out = self.propagate(edge_index, x=x, norm=norm)
        #out = torch.index_select(input=out, index=min(edge_index[0]), dim=0) #NOTE TRICK IS TO PICK MIN EDGE INDEX AS IT WILL CORRESPOND TO THE CENTER NODE OF THE BALL
        # bias
        out += self.bias
        return torch.squeeze(out)

    def message(self, x_j, norm):
        # x_j has shape [E, out_channels]
        # normalize node features
        return norm.view(-1,1) *x_j

In [42]:
class GCNNet(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv = GCNConv(in_channels, hidden_channels)
        self.fc = Linear(hidden_channels, out_channels)
    def forward(self, x, edge_index):
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.fc(self.conv(x, edge_index))
        return x

In [43]:
alpha, beta = 1.0, 1.0
def custom_objective(sigma_1, sigma_2):
  sigma1 = alpha*((sigma_1-0.5)**2) 
  sigma2 = beta*(sigma_2)
  #wandb.log({"neg_indices":sigma1})
  #wandb.log({"pos_indices":sigma2})
  return sigma1+sigma2

In [44]:
# Encoder
class GCNEncoder(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(GCNEncoder, self).__init__()
        self.conv1 = GCNConv(in_channels, 2 * out_channels)
        self.conv_mu = GCNConv(2 * out_channels, out_channels)
        self.conv_logstd = GCNConv(2 * out_channels, out_channels)

    def forward(self,x ,edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv_mu(x, edge_index), self.conv_logstd(x, edge_index)

In [45]:
out_channels= dataset.num_classes
num_features = dataset.num_features
epochs=500
num_nodes_train = data.x.size(0)
neg_indices = neg_index(data.edge_index, adj_matrix, adj_squared)
pos_indices = pos_index(data.edge_index, adj_matrix, adj_squared)

In [46]:
model = VGAE(GCNEncoder(num_features, 16))
gcn_net = GCNNet(in_channels=data.x.size(1), hidden_channels=64, out_channels=dataset.num_classes)
gcn_net = gcn_net.to(device)
model = model.to(device)
x = data.x.to(device)
optimizer1 = torch.optim.Adam(model.parameters(), lr=0.1, weight_decay=0.001)
optimizer2 = torch.optim.Adam(gcn_net.parameters(), lr=0.01, weight_decay=1e-3)

Exception in thread SystemMonitor:
Traceback (most recent call last):
  File "/home/siddy/anaconda3/envs/GDL/lib/python3.8/threading.py", line 932, in _bootstrap_inner
    self.run()
  File "/home/siddy/anaconda3/envs/GDL/lib/python3.8/threading.py", line 870, in run
    self._target(*self._args, **self._kwargs)
  File "/home/siddy/.local/lib/python3.8/site-packages/wandb/sdk/internal/system/system_monitor.py", line 118, in _start
    asset.start()
  File "/home/siddy/.local/lib/python3.8/site-packages/wandb/sdk/internal/system/assets/cpu.py", line 166, in start
    self.metrics_monitor.start()
  File "/home/siddy/.local/lib/python3.8/site-packages/wandb/sdk/internal/system/assets/interfaces.py", line 168, in start
    logger.info(f"Started {self._process.name}")
AttributeError: 'NoneType' object has no attribute 'name'


In [47]:
def train(train_data, neg_indices, pos_indices):
    model.train()
    gcn_net.train()
    optimizer1.zero_grad()
    
    # neg_edge_index = negative_sampling(
    #     edge_index= data.edge_index,
    #     num_nodes= data.x.size(0),
    #     num_neg_samples= data.edge_index.size(1)
    # )
    # neg_indices = neg_index(num_nodes_train, train_data.edge_index)
    # pos_indices = pos_index(num_nodes_train, train_data.edge_index)
    z = model.encode(x, data.edge_index)
    #print(f"latent space shape: {z.shape}")
    #adj = torch.sigmoid(torch.matmul(z, z.t()))
    #print(f"adj matrix shape: {adj.shape}")
    loss = model.recon_loss(z, train_data.edge_index)
    wandb.log({"recon_loss":loss.item()})
    kl_loss = (1 / data.num_nodes) * model.kl_loss()  # new line
    wandb.log({"kl_loss":kl_loss.item()})
    loss += kl_loss
    #adj_binary = (adj > 0.5).float()
    #edge_list = adj.nonzero(as_tuple=False)
    #edge_list = torch.permute(torch.tensor(edge_list, dtype=torch.long), (1,0)).to(device)
    # neg_indices = neg_index(num_nodes_train, train_data.edge_list)
    # pos_indices = pos_index(num_nodes_train, train_data.edge_list)
    for i in range(len(neg_indices)):
        sigma_1 = torch.tensor(neg_indices[i], requires_grad=True)
        sigma_2 = torch.tensor(pos_indices[i], requires_grad=True)
        objective = custom_objective(sigma_1, sigma_2)
        wandb.log({"obj_loss":objective.item()})
        loss += objective
    optimizer2.zero_grad()
    out = gcn_net(data.x, data.edge_index)
    out_2 = gcn_net(data.x, data.edge_index)
    nc_loss = F.cross_entropy(out[data.train_mask[:,1]], data.y[data.train_mask[:,1]])
    nc_loss_2 = F.cross_entropy(out_2[data.train_mask[:,1]], data.y[data.train_mask[:,1]])
    wandb.log({"nc_loss":nc_loss.item()})
    loss += nc_loss
    loss.backward()
    optimizer1.step()
    nc_loss_2.backward()
    optimizer2.step()
    return nc_loss_2, loss


def test(test_data):
    model.eval()
    gcn_net.eval()
    with torch.no_grad():
        # test_neg_edge_index = negative_sampling(
        #     edge_index= test_data.edge_index,
        #     num_nodes= test_data.x.size(0),
        #     num_neg_samples=test_data.edge_index.size(1)
        # )
        # #z = model.encode(x, test_data.edge_index)
        out = gcn_net(data.x, data.edge_index).argmax(dim=-1)
        accs=[]
        for mask in [data.train_mask[:,1], data.val_mask[:,1], data.test_mask[:,1]]:
            accs.append(int((out[mask] == data.y[mask]).sum())/ int(mask.sum()))
        return accs
    
    #return model.test(z, test_data.edge_index, test_neg_edge_index)


In [17]:
from torchviz import make_dot

In [18]:
nc_loss_2, _ = train(data, neg_indices, pos_indices)
nc_dot= make_dot(nc_loss_2, params=dict(gcn_net.named_parameters()))
nc_dot.render("Nc task computational graph", format="png")

'Nc task computational graph.png'

In [19]:
_, loss = train(data, neg_indices, pos_indices)
vgae_dot= make_dot(loss, params=dict(model.named_parameters()))
vgae_dot.render("VGAE task computational graph", format="png")

dot: graph is too large for cairo-renderer bitmaps. Scaling by 0.38154 to fit


'VGAE task computational graph.png'

In [48]:
import time
epochs=500
best_val_acc = final_test_acc = 0
times = []
for epoch in range(1, epochs+1):
    start = time.time()
    _, loss= train(data, neg_indices, pos_indices)
    wandb.log({"total_loss":loss})
    train_acc, val_acc, tmp_test_acc = test(data)
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        test_acc = tmp_test_acc
    print(f"Epoch:{epoch}, Loss:{loss}, Train:{train_acc}, Val:{val_acc}, Test:{test_acc}")
    wandb.log({"train_acc":train_acc})
    wandb.log({"test_acc":test_acc})
    wandb.log({"val_acc":val_acc})

Epoch:1, Loss:670.5535278320312, Train:0.42528735632183906, Val:0.5084745762711864, Test:0.43243243243243246
Epoch:2, Loss:888056.25, Train:0.5632183908045977, Val:0.4915254237288136, Test:0.43243243243243246
Epoch:3, Loss:402381.5, Train:0.42528735632183906, Val:0.3220338983050847, Test:0.43243243243243246
Epoch:4, Loss:2319121.5, Train:0.47126436781609193, Val:0.3220338983050847, Test:0.43243243243243246
Epoch:5, Loss:841112.875, Train:0.5862068965517241, Val:0.423728813559322, Test:0.43243243243243246
Epoch:6, Loss:2020.6790771484375, Train:0.6781609195402298, Val:0.5084745762711864, Test:0.43243243243243246
Epoch:7, Loss:424613.53125, Train:0.6781609195402298, Val:0.5254237288135594, Test:0.5135135135135135
Epoch:8, Loss:2044.8123779296875, Train:0.6781609195402298, Val:0.5254237288135594, Test:0.5135135135135135
Epoch:9, Loss:1673.47021484375, Train:0.6896551724137931, Val:0.5254237288135594, Test:0.5135135135135135
Epoch:10, Loss:2423.095458984375, Train:0.6896551724137931, Val:0

In [49]:
run_vgae_rewire.finish()

0,1
kl_loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
nc_loss,█▃▃▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
obj_loss,▁▁▂▁▅▆▅▁▃▁▁▁▃▅▁▁▁▃▅▃▁▁█▃▁▁█▁▁▁▃█▁▁▁▃▁▁▇▄
recon_loss,█▇▅▄▄▃▂▃▂▅▃▂▃▂▃▂▂▂▃▂▃▅▃▃▆▃▆▅█▅▅▄▃▂▁▂▁▁▂▁
test_acc,▁███████████████████████████████████████
total_loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_acc,▁██▇████████████████████████████████████
val_acc,▁▆▇▅▆▆▆▅▆▆▅▆▆▅▅▅▇▅▆▆▇▆▆▆▆█▇▇▇▆▆▇▆▅▆▇▅▆▅▇

0,1
kl_loss,357.86496
nc_loss,0.72592
obj_loss,0.25
recon_loss,7.64064
test_acc,0.51351
total_loss,986.92633
train_acc,0.68966
val_acc,0.49153
