-
Notifications
You must be signed in to change notification settings - Fork 1
/
models.py
149 lines (114 loc) · 5.53 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
from sklearn.metrics import roc_auc_score, average_precision_score
import torch
from torch_geometric.utils import negative_sampling, remove_self_loops, add_self_loops
from layers import InnerProductDecoder, DirectedInnerProductDecoder
from initializations import reset
EPS = 1e-15
MAX_LOGSTD = 10
class GAE(torch.nn.Module):
    r"""The Graph Auto-Encoder model from the
    `"Variational Graph Auto-Encoders" <https://arxiv.org/abs/1611.07308>`_
    paper based on user-defined encoder and decoder models.
    Args:
        encoder (Module): The encoder module.
        decoder (Module, optional): The decoder module. If set to :obj:`None`,
            will default to the
            :class:`torch_geometric.nn.models.InnerProductDecoder`.
            (default: :obj:`None`)
    """
    def __init__(self, encoder, decoder=None):
        super(GAE, self).__init__()
        self.encoder = encoder
        # Fall back to the standard inner-product decoder when none is given.
        self.decoder = decoder if decoder is not None else InnerProductDecoder()
        GAE.reset_parameters(self)
    def reset_parameters(self):
        r"""Re-initializes the parameters of the encoder and the decoder."""
        for module in (self.encoder, self.decoder):
            reset(module)
    def encode(self, *args, **kwargs):
        r"""Runs the encoder and computes node-wise latent variables."""
        return self.encoder(*args, **kwargs)
    def decode(self, *args, **kwargs):
        r"""Runs the decoder and computes edge probabilities."""
        return self.decoder(*args, **kwargs)
    def recon_loss(self, z, pos_edge_index, neg_edge_index=None):
        r"""Given latent variables :obj:`z`, computes the binary cross
        entropy loss for positive edges :obj:`pos_edge_index` and negative
        sampled edges.
        Args:
            z (Tensor): The latent space :math:`\mathbf{Z}`.
            pos_edge_index (LongTensor): The positive edges to train against.
            neg_edge_index (LongTensor, optional): The negative edges to train
                against. If not given, uses negative sampling to calculate
                negative edges. (default: :obj:`None`)
        """
        pos_prob = self.decoder(z, pos_edge_index, sigmoid=True)
        pos_loss = -torch.log(pos_prob + EPS).mean()
        # Do not include self-loops in negative samples: adding them to the
        # positive set here keeps negative_sampling from ever drawing them.
        pos_edge_index, _ = remove_self_loops(pos_edge_index)
        pos_edge_index, _ = add_self_loops(pos_edge_index)
        if neg_edge_index is None:
            neg_edge_index = negative_sampling(pos_edge_index, z.size(0))
        neg_prob = self.decoder(z, neg_edge_index, sigmoid=True)
        neg_loss = -torch.log(1 - neg_prob + EPS).mean()
        return pos_loss + neg_loss
    def test(self, z, pos_edge_index, neg_edge_index):
        r"""Given latent variables :obj:`z`, positive edges
        :obj:`pos_edge_index` and negative edges :obj:`neg_edge_index`,
        computes area under the ROC curve (AUC) and average precision (AP)
        scores.
        Args:
            z (Tensor): The latent space :math:`\mathbf{Z}`.
            pos_edge_index (LongTensor): The positive edges to evaluate
                against.
            neg_edge_index (LongTensor): The negative edges to evaluate
                against.
        """
        # Labels: 1 for every positive edge, 0 for every negative edge.
        y = torch.cat([
            z.new_ones(pos_edge_index.size(1)),
            z.new_zeros(neg_edge_index.size(1)),
        ], dim=0)
        # Predicted probabilities for the same edges, in the same order.
        pred = torch.cat([
            self.decoder(z, pos_edge_index, sigmoid=True),
            self.decoder(z, neg_edge_index, sigmoid=True),
        ], dim=0)
        y_np = y.detach().cpu().numpy()
        pred_np = pred.detach().cpu().numpy()
        return roc_auc_score(y_np, pred_np), average_precision_score(y_np, pred_np)
class DirectedGAE(torch.nn.Module):
    r"""Graph auto-encoder for directed graphs.
    The encoder produces two embeddings per node — a *source* embedding
    :obj:`s` and a *target* embedding :obj:`t` — and the decoder scores a
    directed edge :math:`(u, v)` from the (asymmetric) product of
    :math:`\mathbf{s}_u` and :math:`\mathbf{t}_v`.
    Args:
        encoder (Module): The encoder module. Called as
            :obj:`encoder(x, x, edge_index)` and expected to return the
            pair :obj:`(s, t)`.
        decoder (Module, optional): The decoder module. If set to
            :obj:`None`, will default to :class:`DirectedInnerProductDecoder`.
            (default: :obj:`None`)
    """
    def __init__(self, encoder, decoder=None):
        super(DirectedGAE, self).__init__()
        self.encoder = encoder
        self.decoder = DirectedInnerProductDecoder() if decoder is None else decoder
        DirectedGAE.reset_parameters(self)
    def reset_parameters(self):
        r"""Re-initializes the parameters of the encoder and the decoder."""
        reset(self.encoder)
        reset(self.decoder)
    def forward(self, data):
        r"""Encodes :obj:`data` and returns the decoder's dense
        reconstruction of the full adjacency matrix.
        Args:
            data: A graph data object providing :obj:`x` and
                :obj:`edge_index` (e.g. a :class:`torch_geometric.data.Data`).
        """
        # NOTE: data.edge_weight was previously unpacked here but never
        # used, so it is no longer read.
        x, edge_index = data.x, data.edge_index
        s, t = self.encoder(x, x, edge_index)
        adj_pred = self.decoder.forward_all(s, t)
        return adj_pred
    def encode(self, *args, **kwargs):
        r"""Runs the encoder and computes the (source, target) embeddings."""
        return self.encoder(*args, **kwargs)
    def decode(self, *args, **kwargs):
        r"""Runs the decoder and computes edge probabilities."""
        return self.decoder(*args, **kwargs)
    def recon_loss(self, s, t, pos_edge_index, neg_edge_index=None):
        r"""Given source/target embeddings :obj:`s` and :obj:`t`, computes
        the binary cross entropy loss for positive edges
        :obj:`pos_edge_index` and negative sampled edges.
        Args:
            s (Tensor): The source-role latent space.
            t (Tensor): The target-role latent space.
            pos_edge_index (LongTensor): The positive edges to train against.
            neg_edge_index (LongTensor, optional): The negative edges to train
                against. If not given, uses negative sampling to calculate
                negative edges. (default: :obj:`None`)
        """
        pos_loss = -torch.log(
            self.decoder(s, t, pos_edge_index, sigmoid=True) + EPS).mean()
        # Do not include self-loops in negative samples: adding them to the
        # positive set here keeps negative_sampling from ever drawing them.
        pos_edge_index, _ = remove_self_loops(pos_edge_index)
        pos_edge_index, _ = add_self_loops(pos_edge_index)
        if neg_edge_index is None:
            neg_edge_index = negative_sampling(pos_edge_index, s.size(0))
        neg_loss = -torch.log(1 -
                              self.decoder(s, t, neg_edge_index, sigmoid=True) +
                              EPS).mean()
        return pos_loss + neg_loss
    def test(self, s, t, pos_edge_index, neg_edge_index):
        r"""Given source/target embeddings :obj:`s` and :obj:`t`, positive
        edges :obj:`pos_edge_index` and negative edges :obj:`neg_edge_index`,
        computes area under the ROC curve (AUC) and average precision (AP)
        scores.
        Args:
            s (Tensor): The source-role latent space.
            t (Tensor): The target-role latent space.
            pos_edge_index (LongTensor): The positive edges to evaluate
                against.
            neg_edge_index (LongTensor): The negative edges to evaluate
                against.
        """
        pos_y = s.new_ones(pos_edge_index.size(1))
        neg_y = s.new_zeros(neg_edge_index.size(1))
        y = torch.cat([pos_y, neg_y], dim=0)
        pos_pred = self.decoder(s, t, pos_edge_index, sigmoid=True)
        neg_pred = self.decoder(s, t, neg_edge_index, sigmoid=True)
        pred = torch.cat([pos_pred, neg_pred], dim=0)
        y, pred = y.detach().cpu().numpy(), pred.detach().cpu().numpy()
        return roc_auc_score(y, pred), average_precision_score(y, pred)