In [1]:
import pandas as pd
import networkx as nx
import os.path as osp

import torch
import torch_geometric
from torch_geometric.data import Dataset, download_url
from torch_geometric.utils.convert import from_networkx
import numpy as np


import torch.nn.functional as F
from torch_geometric.nn import GCNConv,Linear
from torch_geometric.nn import GAE, Node2Vec,VGAE
from torch.utils.tensorboard import SummaryWriter
from torch_geometric.nn.models.autoencoder import ARGVA

from torch_geometric.transforms import RandomLinkSplit
from tqdm import tqdm

In [2]:
torch.manual_seed(0)

<torch._C.Generator at 0x297cd6fb5b0>

In [3]:
# import dataset
df=pd.read_csv('PPI.csv')
G=nx.from_pandas_edgelist(df,'Official Symbol Interactor A','Official Symbol Interactor B' )
#Gcc = sorted(nx.connected_components(G), key=len, reverse=True)
#G0 = G.subgraph(Gcc[0])
G=nx.convert_node_labels_to_integers(G)
pyg_graph = from_networkx(G)

node_embedding=Node2Vec(pyg_graph.edge_index,128,16,10)

#embedding del nodo 0
#node_embedding.forward().data[0]# cambia ogni volta anche con il seed 

#aggiungiamo gli embedding come features dei nodi
for n in G.nodes():
    G.nodes[n]['x']=np.array(node_embedding.forward().data[n])
    
pyg_graph = from_networkx(G)

  data[key] = torch.tensor(value)


In [4]:
transform = RandomLinkSplit(is_undirected=False,split_labels=True,
                      neg_sampling_ratio=1.0,
                      key = "edge_label",
                      disjoint_train_ratio=0,
                      num_val =0)
train_data, val_data, test_data = transform(pyg_graph)

In [5]:
train_data

Data(x=[19776, 128], edge_index=[2, 1118815], pos_edge_label=[1118815], pos_edge_label_index=[2, 1118815], neg_edge_label=[1118815], neg_edge_label_index=[2, 1118815])

## GAE

In [7]:
class GCNEncoder(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(GCNEncoder, self).__init__()
        self.conv1 = GCNConv(in_channels, 2 * out_channels, cached=True) # cached only for transductive learning
        self.conv2 = GCNConv(2 * out_channels, out_channels, cached=True) # cached only for transductive learning

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x

In [8]:
def train(data):
    model.train()
    optimizer.zero_grad()
    z = model.encode(data.x, data.edge_index)
    pos_edge_index=data.pos_edge_label_index
    neg_edge_index=data.neg_edge_label_index
    loss = model.recon_loss(z, pos_edge_index,neg_edge_index) 
    loss.backward()
    optimizer.step()
    return float(loss)


def test(data):
    model.eval()
    with torch.no_grad():
        z = model.encode(data.x,data.edge_index)
        pos_edge_index=data.pos_edge_label_index
        neg_edge_index=data.neg_edge_label_index
    return model.test(z, pos_edge_index, neg_edge_index)

In [9]:
# parameters
out_channels = 10   #embedding 
num_features = train_data.x.shape[1] 
epochs = 100


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# model
model = GAE(GCNEncoder(num_features, out_channels))
model = model.to(device)

# inizialize the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [10]:
writer = SummaryWriter('runs_4/GAE_experiment'+'10d_100_epochs')

In [11]:
for epoch in tqdm(range(1, epochs + 1)):
    loss = train(train_data)
    auc, ap = test(test_data)
    print('Epoch: {:03d}, AUC: {:.4f}, AP: {:.4f}'.format(epoch, auc, ap))
    
    writer.add_scalar('loss train',loss,epoch)
    writer.add_scalar('auc train',auc,epoch) 
    writer.add_scalar('ap train',ap,epoch)  

  1%|▊                                                                                 | 1/100 [00:01<02:51,  1.74s/it]

Epoch: 001, AUC: 0.7767, AP: 0.7167


  2%|█▋                                                                                | 2/100 [00:03<02:38,  1.62s/it]

Epoch: 002, AUC: 0.7938, AP: 0.7431


  3%|██▍                                                                               | 3/100 [00:04<02:32,  1.58s/it]

Epoch: 003, AUC: 0.8071, AP: 0.7634


  4%|███▎                                                                              | 4/100 [00:06<02:24,  1.51s/it]

Epoch: 004, AUC: 0.8199, AP: 0.7822


  5%|████                                                                              | 5/100 [00:07<02:25,  1.53s/it]

Epoch: 005, AUC: 0.8308, AP: 0.7983


  6%|████▉                                                                             | 6/100 [00:09<02:21,  1.50s/it]

Epoch: 006, AUC: 0.8395, AP: 0.8115


  7%|█████▋                                                                            | 7/100 [00:10<02:18,  1.49s/it]

Epoch: 007, AUC: 0.8465, AP: 0.8222


  8%|██████▌                                                                           | 8/100 [00:12<02:17,  1.49s/it]

Epoch: 008, AUC: 0.8517, AP: 0.8306


  9%|███████▍                                                                          | 9/100 [00:13<02:14,  1.48s/it]

Epoch: 009, AUC: 0.8555, AP: 0.8372


 10%|████████                                                                         | 10/100 [00:15<02:13,  1.49s/it]

Epoch: 010, AUC: 0.8580, AP: 0.8423


 11%|████████▉                                                                        | 11/100 [00:16<02:12,  1.49s/it]

Epoch: 011, AUC: 0.8598, AP: 0.8461


 12%|█████████▋                                                                       | 12/100 [00:18<02:14,  1.53s/it]

Epoch: 012, AUC: 0.8612, AP: 0.8491


 13%|██████████▌                                                                      | 13/100 [00:19<02:16,  1.57s/it]

Epoch: 013, AUC: 0.8622, AP: 0.8514


 14%|███████████▎                                                                     | 14/100 [00:21<02:19,  1.62s/it]

Epoch: 014, AUC: 0.8631, AP: 0.8534


 15%|████████████▏                                                                    | 15/100 [00:23<02:14,  1.59s/it]

Epoch: 015, AUC: 0.8636, AP: 0.8550


 16%|████████████▉                                                                    | 16/100 [00:24<02:10,  1.56s/it]

Epoch: 016, AUC: 0.8638, AP: 0.8561


 17%|█████████████▊                                                                   | 17/100 [00:26<02:08,  1.55s/it]

Epoch: 017, AUC: 0.8635, AP: 0.8567


 18%|██████████████▌                                                                  | 18/100 [00:27<02:10,  1.60s/it]

Epoch: 018, AUC: 0.8629, AP: 0.8567


 19%|███████████████▍                                                                 | 19/100 [00:29<02:14,  1.66s/it]

Epoch: 019, AUC: 0.8622, AP: 0.8567


 20%|████████████████▏                                                                | 20/100 [00:31<02:10,  1.64s/it]

Epoch: 020, AUC: 0.8623, AP: 0.8573


 21%|█████████████████                                                                | 21/100 [00:32<02:07,  1.62s/it]

Epoch: 021, AUC: 0.8633, AP: 0.8588


 22%|█████████████████▊                                                               | 22/100 [00:34<02:05,  1.61s/it]

Epoch: 022, AUC: 0.8648, AP: 0.8608


 23%|██████████████████▋                                                              | 23/100 [00:35<02:01,  1.58s/it]

Epoch: 023, AUC: 0.8662, AP: 0.8626


 24%|███████████████████▍                                                             | 24/100 [00:37<02:00,  1.59s/it]

Epoch: 024, AUC: 0.8670, AP: 0.8638


 25%|████████████████████▎                                                            | 25/100 [00:39<01:56,  1.55s/it]

Epoch: 025, AUC: 0.8675, AP: 0.8647


 26%|█████████████████████                                                            | 26/100 [00:40<01:55,  1.56s/it]

Epoch: 026, AUC: 0.8684, AP: 0.8660


 27%|█████████████████████▊                                                           | 27/100 [00:42<01:54,  1.57s/it]

Epoch: 027, AUC: 0.8701, AP: 0.8681


 28%|██████████████████████▋                                                          | 28/100 [00:43<01:51,  1.55s/it]

Epoch: 028, AUC: 0.8727, AP: 0.8709


 29%|███████████████████████▍                                                         | 29/100 [00:45<01:49,  1.55s/it]

Epoch: 029, AUC: 0.8753, AP: 0.8737


 30%|████████████████████████▎                                                        | 30/100 [00:46<01:47,  1.54s/it]

Epoch: 030, AUC: 0.8773, AP: 0.8758


 31%|█████████████████████████                                                        | 31/100 [00:48<01:44,  1.52s/it]

Epoch: 031, AUC: 0.8786, AP: 0.8772


 32%|█████████████████████████▉                                                       | 32/100 [00:49<01:42,  1.50s/it]

Epoch: 032, AUC: 0.8796, AP: 0.8784


 33%|██████████████████████████▋                                                      | 33/100 [00:51<01:40,  1.49s/it]

Epoch: 033, AUC: 0.8810, AP: 0.8799


 34%|███████████████████████████▌                                                     | 34/100 [00:52<01:36,  1.47s/it]

Epoch: 034, AUC: 0.8829, AP: 0.8818


 35%|████████████████████████████▎                                                    | 35/100 [00:54<01:34,  1.46s/it]

Epoch: 035, AUC: 0.8849, AP: 0.8836


 36%|█████████████████████████████▏                                                   | 36/100 [00:55<01:34,  1.48s/it]

Epoch: 036, AUC: 0.8862, AP: 0.8848


 37%|█████████████████████████████▉                                                   | 37/100 [00:56<01:32,  1.46s/it]

Epoch: 037, AUC: 0.8867, AP: 0.8853


 38%|██████████████████████████████▊                                                  | 38/100 [00:58<01:31,  1.48s/it]

Epoch: 038, AUC: 0.8870, AP: 0.8855


 39%|███████████████████████████████▌                                                 | 39/100 [00:59<01:28,  1.45s/it]

Epoch: 039, AUC: 0.8877, AP: 0.8860


 40%|████████████████████████████████▍                                                | 40/100 [01:01<01:26,  1.45s/it]

Epoch: 040, AUC: 0.8887, AP: 0.8870


 41%|█████████████████████████████████▏                                               | 41/100 [01:02<01:24,  1.43s/it]

Epoch: 041, AUC: 0.8897, AP: 0.8878


 42%|██████████████████████████████████                                               | 42/100 [01:04<01:22,  1.42s/it]

Epoch: 042, AUC: 0.8900, AP: 0.8880


 43%|██████████████████████████████████▊                                              | 43/100 [01:05<01:21,  1.43s/it]

Epoch: 043, AUC: 0.8898, AP: 0.8878


 44%|███████████████████████████████████▋                                             | 44/100 [01:07<01:20,  1.45s/it]

Epoch: 044, AUC: 0.8897, AP: 0.8877


 45%|████████████████████████████████████▍                                            | 45/100 [01:08<01:19,  1.45s/it]

Epoch: 045, AUC: 0.8901, AP: 0.8880


 46%|█████████████████████████████████████▎                                           | 46/100 [01:09<01:18,  1.45s/it]

Epoch: 046, AUC: 0.8907, AP: 0.8885


 47%|██████████████████████████████████████                                           | 47/100 [01:11<01:17,  1.46s/it]

Epoch: 047, AUC: 0.8908, AP: 0.8886


 48%|██████████████████████████████████████▉                                          | 48/100 [01:12<01:14,  1.44s/it]

Epoch: 048, AUC: 0.8904, AP: 0.8883


 49%|███████████████████████████████████████▋                                         | 49/100 [01:14<01:12,  1.43s/it]

Epoch: 049, AUC: 0.8900, AP: 0.8880


 50%|████████████████████████████████████████▌                                        | 50/100 [01:15<01:10,  1.42s/it]

Epoch: 050, AUC: 0.8901, AP: 0.8881


 51%|█████████████████████████████████████████▎                                       | 51/100 [01:17<01:09,  1.42s/it]

Epoch: 051, AUC: 0.8905, AP: 0.8885


 52%|██████████████████████████████████████████                                       | 52/100 [01:18<01:08,  1.42s/it]

Epoch: 052, AUC: 0.8905, AP: 0.8885


 53%|██████████████████████████████████████████▉                                      | 53/100 [01:19<01:07,  1.43s/it]

Epoch: 053, AUC: 0.8901, AP: 0.8883


 54%|███████████████████████████████████████████▋                                     | 54/100 [01:21<01:06,  1.44s/it]

Epoch: 054, AUC: 0.8899, AP: 0.8881


 55%|████████████████████████████████████████████▌                                    | 55/100 [01:22<01:05,  1.45s/it]

Epoch: 055, AUC: 0.8902, AP: 0.8884


 56%|█████████████████████████████████████████████▎                                   | 56/100 [01:24<01:03,  1.44s/it]

Epoch: 056, AUC: 0.8905, AP: 0.8888


 57%|██████████████████████████████████████████████▏                                  | 57/100 [01:25<01:02,  1.46s/it]

Epoch: 057, AUC: 0.8905, AP: 0.8888


 58%|██████████████████████████████████████████████▉                                  | 58/100 [01:27<01:00,  1.44s/it]

Epoch: 058, AUC: 0.8902, AP: 0.8886


 59%|███████████████████████████████████████████████▊                                 | 59/100 [01:28<00:59,  1.45s/it]

Epoch: 059, AUC: 0.8901, AP: 0.8886


 60%|████████████████████████████████████████████████▌                                | 60/100 [01:30<00:58,  1.46s/it]

Epoch: 060, AUC: 0.8904, AP: 0.8889


 61%|█████████████████████████████████████████████████▍                               | 61/100 [01:31<00:56,  1.45s/it]

Epoch: 061, AUC: 0.8906, AP: 0.8891


 62%|██████████████████████████████████████████████████▏                              | 62/100 [01:32<00:54,  1.43s/it]

Epoch: 062, AUC: 0.8905, AP: 0.8891


 63%|███████████████████████████████████████████████████                              | 63/100 [01:34<00:53,  1.45s/it]

Epoch: 063, AUC: 0.8903, AP: 0.8890


 64%|███████████████████████████████████████████████████▊                             | 64/100 [01:35<00:51,  1.44s/it]

Epoch: 064, AUC: 0.8905, AP: 0.8892


 65%|████████████████████████████████████████████████████▋                            | 65/100 [01:37<00:50,  1.44s/it]

Epoch: 065, AUC: 0.8908, AP: 0.8895


 66%|█████████████████████████████████████████████████████▍                           | 66/100 [01:38<00:48,  1.43s/it]

Epoch: 066, AUC: 0.8908, AP: 0.8895


 67%|██████████████████████████████████████████████████████▎                          | 67/100 [01:40<00:47,  1.45s/it]

Epoch: 067, AUC: 0.8907, AP: 0.8894


 68%|███████████████████████████████████████████████████████                          | 68/100 [01:41<00:46,  1.45s/it]

Epoch: 068, AUC: 0.8908, AP: 0.8895


 69%|███████████████████████████████████████████████████████▉                         | 69/100 [01:43<00:44,  1.44s/it]

Epoch: 069, AUC: 0.8911, AP: 0.8898


 70%|████████████████████████████████████████████████████████▋                        | 70/100 [01:44<00:42,  1.43s/it]

Epoch: 070, AUC: 0.8912, AP: 0.8899


 71%|█████████████████████████████████████████████████████████▌                       | 71/100 [01:45<00:41,  1.44s/it]

Epoch: 071, AUC: 0.8912, AP: 0.8899


 72%|██████████████████████████████████████████████████████████▎                      | 72/100 [01:47<00:40,  1.44s/it]

Epoch: 072, AUC: 0.8912, AP: 0.8899


 73%|███████████████████████████████████████████████████████████▏                     | 73/100 [01:48<00:38,  1.43s/it]

Epoch: 073, AUC: 0.8914, AP: 0.8901


 74%|███████████████████████████████████████████████████████████▉                     | 74/100 [01:50<00:37,  1.45s/it]

Epoch: 074, AUC: 0.8915, AP: 0.8902


 75%|████████████████████████████████████████████████████████████▊                    | 75/100 [01:51<00:36,  1.45s/it]

Epoch: 075, AUC: 0.8914, AP: 0.8901


 76%|█████████████████████████████████████████████████████████████▌                   | 76/100 [01:53<00:34,  1.43s/it]

Epoch: 076, AUC: 0.8914, AP: 0.8901


 77%|██████████████████████████████████████████████████████████████▎                  | 77/100 [01:54<00:32,  1.42s/it]

Epoch: 077, AUC: 0.8916, AP: 0.8903


 78%|███████████████████████████████████████████████████████████████▏                 | 78/100 [01:55<00:31,  1.43s/it]

Epoch: 078, AUC: 0.8917, AP: 0.8904


 79%|███████████████████████████████████████████████████████████████▉                 | 79/100 [01:57<00:29,  1.43s/it]

Epoch: 079, AUC: 0.8917, AP: 0.8903


 80%|████████████████████████████████████████████████████████████████▊                | 80/100 [01:58<00:28,  1.43s/it]

Epoch: 080, AUC: 0.8917, AP: 0.8903


 81%|█████████████████████████████████████████████████████████████████▌               | 81/100 [02:00<00:27,  1.44s/it]

Epoch: 081, AUC: 0.8918, AP: 0.8905


 82%|██████████████████████████████████████████████████████████████████▍              | 82/100 [02:01<00:25,  1.43s/it]

Epoch: 082, AUC: 0.8919, AP: 0.8905


 83%|███████████████████████████████████████████████████████████████████▏             | 83/100 [02:03<00:24,  1.44s/it]

Epoch: 083, AUC: 0.8918, AP: 0.8904


 84%|████████████████████████████████████████████████████████████████████             | 84/100 [02:04<00:22,  1.44s/it]

Epoch: 084, AUC: 0.8919, AP: 0.8905


 85%|████████████████████████████████████████████████████████████████████▊            | 85/100 [02:05<00:21,  1.43s/it]

Epoch: 085, AUC: 0.8920, AP: 0.8906


 86%|█████████████████████████████████████████████████████████████████████▋           | 86/100 [02:07<00:20,  1.44s/it]

Epoch: 086, AUC: 0.8919, AP: 0.8906


 87%|██████████████████████████████████████████████████████████████████████▍          | 87/100 [02:08<00:18,  1.46s/it]

Epoch: 087, AUC: 0.8919, AP: 0.8906


 88%|███████████████████████████████████████████████████████████████████████▎         | 88/100 [02:10<00:17,  1.46s/it]

Epoch: 088, AUC: 0.8920, AP: 0.8907


 89%|████████████████████████████████████████████████████████████████████████         | 89/100 [02:11<00:16,  1.47s/it]

Epoch: 089, AUC: 0.8920, AP: 0.8907


 90%|████████████████████████████████████████████████████████████████████████▉        | 90/100 [02:13<00:14,  1.47s/it]

Epoch: 090, AUC: 0.8919, AP: 0.8907


 91%|█████████████████████████████████████████████████████████████████████████▋       | 91/100 [02:14<00:13,  1.46s/it]

Epoch: 091, AUC: 0.8919, AP: 0.8907


 92%|██████████████████████████████████████████████████████████████████████████▌      | 92/100 [02:16<00:11,  1.45s/it]

Epoch: 092, AUC: 0.8919, AP: 0.8908


 93%|███████████████████████████████████████████████████████████████████████████▎     | 93/100 [02:17<00:10,  1.47s/it]

Epoch: 093, AUC: 0.8919, AP: 0.8907


 94%|████████████████████████████████████████████████████████████████████████████▏    | 94/100 [02:19<00:08,  1.46s/it]

Epoch: 094, AUC: 0.8919, AP: 0.8907


 95%|████████████████████████████████████████████████████████████████████████████▉    | 95/100 [02:20<00:07,  1.46s/it]

Epoch: 095, AUC: 0.8919, AP: 0.8908


 96%|█████████████████████████████████████████████████████████████████████████████▊   | 96/100 [02:21<00:05,  1.44s/it]

Epoch: 096, AUC: 0.8919, AP: 0.8908


 97%|██████████████████████████████████████████████████████████████████████████████▌  | 97/100 [02:23<00:04,  1.45s/it]

Epoch: 097, AUC: 0.8918, AP: 0.8907


 98%|███████████████████████████████████████████████████████████████████████████████▍ | 98/100 [02:24<00:02,  1.45s/it]

Epoch: 098, AUC: 0.8919, AP: 0.8908


 99%|████████████████████████████████████████████████████████████████████████████████▏| 99/100 [02:26<00:01,  1.44s/it]

Epoch: 099, AUC: 0.8919, AP: 0.8908


100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [02:27<00:00,  1.48s/it]

Epoch: 100, AUC: 0.8918, AP: 0.8908





## DeepGAE

In [42]:
class DeepGCNEncoder(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DeepGCNEncoder, self).__init__()
        self.conv1 = GCNConv(in_channels, 2 * out_channels, cached=True)
        self.conv2 = GCNConv(2 * out_channels, 2 * out_channels, cached=True)
        self.conv3 = GCNConv(2 * out_channels, out_channels, cached=True)

    def forward(self, x, edge_index,edge_weight=None):
        x = self.conv1(x, edge_index,edge_weight=edge_weight).relu()
        x = self.conv2(x, edge_index,edge_weight=edge_weight).relu()
        return self.conv3(x, edge_index,edge_weight=edge_weight)

In [48]:
# parameters
out_channels = 20   #embedding 
num_features = train_data.x.shape[1] 
epochs = 100


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# model
model = GAE(DeepGCNEncoder(num_features, out_channels))
model = model.to(device)

# inizialize the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [49]:
writer = SummaryWriter('runs_3/DeepGAE_experiment'+'20d_100_epochs')

for epoch in tqdm(range(1, epochs + 1)):
    loss = train(train_data)
    auc, ap = test(test_data)
    #print('Epoch: {:03d}, AUC: {:.4f}, AP: {:.4f}'.format(epoch, auc, ap))
    
    writer.add_scalar('loss train',loss,epoch)
    writer.add_scalar('auc train',auc,epoch) 
    writer.add_scalar('ap train',ap,epoch) 

100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [08:21<00:00,  5.02s/it]


## VGAE

In [50]:
class VariationalGCNEncoder(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(VariationalGCNEncoder, self).__init__()
        self.conv1 = GCNConv(in_channels, 2 * out_channels, cached=True)
        self.conv_mu = GCNConv(2 * out_channels, out_channels, cached=True)
        self.conv_logstd = GCNConv(2 * out_channels, out_channels, cached=True)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv_mu(x, edge_index), self.conv_logstd(x, edge_index)

In [54]:
out_channels = 20   #embedding 
num_features = train_data.x.shape[1] 
epochs = 100

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = VGAE(VariationalGCNEncoder(num_features, out_channels)) 
model = model.to(device)


optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [52]:
def train_VGAE(data):
    model.train()
    optimizer.zero_grad()
    z = model.encode(data.x, data.edge_index)
    pos_edge_index=data.pos_edge_label_index
    neg_edge_index=data.neg_edge_label_index
    loss = model.recon_loss(z, pos_edge_index,neg_edge_index) 
    loss = loss + (1 / data.x.shape[0]) * model.kl_loss()
    loss.backward()
    optimizer.step()
    return float(loss)

In [55]:
writer = SummaryWriter('runs_3/VGAE_experiment'+'20d_100_epochs')

for epoch in tqdm(range(1, epochs + 1)):
    loss = train_VGAE(train_data)
    auc, ap = test(test_data)
    #print('Epoch: {:03d}, AUC: {:.4f}, AP: {:.4f}'.format(epoch, auc, ap))
    
    writer.add_scalar('loss train',loss,epoch)
    writer.add_scalar('auc train',auc,epoch) 
    writer.add_scalar('ap train',ap,epoch) 

100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [06:25<00:00,  3.85s/it]


## ARGVA

In [59]:
class VariationalGCNEncoder(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(VariationalGCNEncoder, self).__init__()
        self.conv1 = GCNConv(in_channels, 2 * out_channels, cached=True)
        self.conv_mu = GCNConv(2 * out_channels, out_channels, cached=True)
        self.conv_logstd = GCNConv(2 * out_channels, out_channels, cached=True)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv_mu(x, edge_index), self.conv_logstd(x, edge_index)
    
class Discriminator(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.lin1 = Linear(in_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, hidden_channels)
        self.lin3 = Linear(hidden_channels, out_channels)

    def forward(self, x):
        x = self.lin1(x).relu()
        x = self.lin2(x).relu()
        return self.lin3(x)

In [60]:
def train_ARGVA(data):
    model.train()
    encoder_optimizer.zero_grad()
    z = model.encode(data.x, data.edge_index)
    pos_edge_index=data.pos_edge_label_index
    neg_edge_index=data.neg_edge_label_index
    
    for i in range(5):  
        #discriminator.train()
        discriminator_optimizer.zero_grad()
        discriminator_loss = model.discriminator_loss(z)
        discriminator_loss.backward()
        discriminator_optimizer.step()
 
    loss = model.recon_loss(z, pos_edge_index,neg_edge_index) 
    loss = loss + model.reg_loss(z)
    loss = loss + (1 / data.x.shape[0]) * model.kl_loss()
    
    loss.backward()
    encoder_optimizer.step()
    return float(loss)

In [63]:
embedding = 20   #embedding 
num_features = train_data.x.shape[1] 
epochs = 100

encoder = VariationalGCNEncoder(num_features, embedding)

discriminator = Discriminator(in_channels=embedding, hidden_channels=embedding//2, 
                              out_channels=1)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = ARGVA(encoder, discriminator)
model = model.to(device)

discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.001)
encoder_optimizer = torch.optim.Adam(encoder.parameters(), lr=0.005)

In [64]:
writer = SummaryWriter('runs_3/ARGVA_experiment'+'20d_100_epochs')

for epoch in tqdm(range(1, epochs + 1)):
    loss = train_ARGVA(train_data)
    auc, ap = test(test_data)
    #print('Epoch: {:03d}, AUC: {:.4f}, AP: {:.4f}'.format(epoch, auc, ap))
    
    writer.add_scalar('loss train',loss,epoch)
    writer.add_scalar('auc train',auc,epoch) 
    writer.add_scalar('ap train',ap,epoch) 

100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [05:46<00:00,  3.46s/it]


## GAE with Linear Decoder

In [13]:
class GCNDecoder(torch.nn.Module):
    def __init__(self, latent_dim):
        super(GCNDecoder, self).__init__()
        self.lin1 = Linear(latent_dim,latent_dim)
        self.lin2 = Linear(latent_dim,latent_dim//2)
        self.lin3 = Linear(latent_dim//2,1)

    def forward(self, z, edge_index, sigmoid=True):

        z = (z[edge_index[0]] * z[edge_index[1]])#.sum(dim=1)
        z = self.lin1(z).relu()
        z = self.lin2(z).relu()
        z = self.lin3(z)
        z = z.squeeze()
        
        return torch.sigmoid(z) if sigmoid else value    

In [14]:
# parameters
out_channels = 10   #embedding 
num_features = train_data.x.shape[1] 
epochs = 100

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = GAE(GCNEncoder(num_features, out_channels),GCNDecoder(out_channels))
model = model.to(device)

# inizialize the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [15]:
writer = SummaryWriter('runs_4/GAE+dec_experiment'+'10d_100_epochs')

for epoch in tqdm(range(1, epochs + 1)):
    loss = train(train_data)
    auc, ap = test(test_data)
    #print('Epoch: {:03d}, AUC: {:.4f}, AP: {:.4f}'.format(epoch, auc, ap))
    
    writer.add_scalar('loss train',loss,epoch)
    writer.add_scalar('auc train',auc,epoch) 
    writer.add_scalar('ap train',ap,epoch) 

100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [02:51<00:00,  1.72s/it]


## DeepGAE with Linear Decoder

In [125]:
# parameters
out_channels = 20   #embedding 
num_features = train_data.x.shape[1] 
epochs = 100

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = GAE(DeepGCNEncoder(num_features, out_channels),GCNDecoder(out_channels))
model = model.to(device)

# inizialize the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [126]:
writer = SummaryWriter('runs_3/DeepGAE+dec_experiment'+'20d_100_epochs')

for epoch in tqdm(range(1, epochs + 1)):
    loss = train(train_data)
    auc, ap = test(test_data)
    #print('Epoch: {:03d}, AUC: {:.4f}, AP: {:.4f}'.format(epoch, auc, ap))
    
    writer.add_scalar('loss train',loss,epoch)
    writer.add_scalar('auc train',auc,epoch) 
    writer.add_scalar('ap train',ap,epoch) 

100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [07:59<00:00,  4.80s/it]


## VGAE with Linear Decoder

In [129]:
out_channels = 20   #embedding 
num_features = train_data.x.shape[1] 
epochs = 100

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = VGAE(VariationalGCNEncoder(num_features, out_channels),GCNDecoder(out_channels)) 
model = model.to(device)


optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [130]:
writer = SummaryWriter('runs_3/VGAE+dec_experiment'+'20d_100_epochs')

for epoch in tqdm(range(1, epochs + 1)):
    loss = train_VGAE(train_data)
    auc, ap = test(test_data)
    #print('Epoch: {:03d}, AUC: {:.4f}, AP: {:.4f}'.format(epoch, auc, ap))
    
    writer.add_scalar('loss train',loss,epoch)
    writer.add_scalar('auc train',auc,epoch) 
    writer.add_scalar('ap train',ap,epoch) 

100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [06:33<00:00,  3.94s/it]


## ARGVA with linear Decoder (AUC e AP basse)

In [147]:
class Discriminator_sig(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.lin1 = Linear(in_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, hidden_channels)
        self.lin3 = Linear(hidden_channels, out_channels)

    def forward(self, x):
        x = self.lin1(x).relu()
        x = self.lin2(x).relu()
        return torch.sigmoid(self.lin3(x))#added sigmoid

In [145]:
embedding = 10   #embedding 
num_features = train_data.x.shape[1] 
epochs = 100

encoder = VariationalGCNEncoder(num_features, embedding)

#discriminator = Discriminator(in_channels=embedding, hidden_channels=embedding//2, 
#                              out_channels=1)
    
discriminator = Discriminator_sig(in_channels=embedding, hidden_channels=embedding//2, 
                              out_channels=1)
decoder=GCNDecoder(embedding)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = ARGVA(encoder, discriminator,decoder)
model = model.to(device)

discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.001)#0.001
encoder_optimizer = torch.optim.Adam(encoder.parameters(), lr=0.005)#0.005

In [None]:
writer = SummaryWriter('runs_3/ARGVAsig+dec_experiment'+'10d_100_epochs')

for epoch in tqdm(range(1, epochs + 1)):
    loss = train_ARGVA(train_data)
    auc, ap = test(test_data)
    #print('Epoch: {:03d}, AUC: {:.4f}, AP: {:.4f}'.format(epoch, auc, ap))
    
    writer.add_scalar('loss train',loss,epoch)
    writer.add_scalar('auc train',auc,epoch) 
    writer.add_scalar('ap train',ap,epoch) 
writer.close()