In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable

import random

In [2]:
class MeanAggregator(nn.Module):
    """
    Aggregates a node's embeddings using the mean of its (sampled) neighbors' embeddings.
    """
    def __init__(self, features, cuda=False, gcn=False):
        """
        Initializes the aggregator for a specific graph.

        features -- function mapping a LongTensor of node ids to a 2-D FloatTensor
                    of feature rows (one row per id).
        cuda -- whether to use GPU
        gcn -- if True, add each node to its own neighborhood (GCN-style self-loop);
               otherwise only neighbors are averaged and the node's own features are
               expected to be concatenated by the caller (GraphSAGE-style).
        """
        super(MeanAggregator, self).__init__()

        self.features = features
        # NOTE: this bool shadows nn.Module.cuda(); the name is kept because callers
        # assign `aggregator.cuda = ...` directly.
        self.cuda = cuda
        self.gcn = gcn

    def forward(self, nodes, to_neighs, num_sample=10):
        """
        Mean-aggregates neighbor features for a batch of nodes.

        nodes -- list of node ids in the batch
        to_neighs -- list of neighbor collections; to_neighs[i] belongs to nodes[i]
        num_sample -- number of neighbors to sample per node. No sampling if None;
                      a node with fewer than num_sample neighbors keeps all of them.

        Returns a tensor with one row per entry of to_neighs: the mean of the
        features of that node's sampled neighbors (all zeros if it has none).
        """
        if num_sample is not None:
            # random.sample requires a sequence (passing a set raises TypeError on Python 3.11+).
            samp_neighs = [set(random.sample(list(to_neigh), num_sample))
                           if len(to_neigh) >= num_sample else set(to_neigh)
                           for to_neigh in to_neighs]
        else:
            samp_neighs = [set(to_neigh) for to_neigh in to_neighs]
        if self.gcn:
            # Self-loop: sets have no "+", so union each neighborhood with the node itself.
            samp_neighs = [samp_neigh | {int(node)}
                           for node, samp_neigh in zip(nodes, samp_neighs)]

        # Mask columns = distinct sampled neighbors across the whole batch.
        # set().union(...) also copes with an empty batch, unlike set.union(*[]).
        unique_nodes_list = list(set().union(*samp_neighs))
        unique_nodes = {n: i for i, n in enumerate(unique_nodes_list)}
        # mask[i, j] = 1 iff unique_nodes_list[j] was sampled for row i (adjacency after sampling).
        mask = torch.zeros(len(samp_neighs), len(unique_nodes))
        column_indices = [unique_nodes[n] for samp_neigh in samp_neighs for n in samp_neigh]
        row_indices = [i for i, samp_neigh in enumerate(samp_neighs) for _ in samp_neigh]
        if row_indices:
            mask[row_indices, column_indices] = 1

        node_ids = torch.LongTensor(unique_nodes_list)
        if self.cuda:
            mask = mask.cuda()
            node_ids = node_ids.cuda()
        # Row-normalize so mask @ features is a mean; clamp(min=1) keeps the rows of
        # isolated nodes at 0 instead of 0/0 = NaN.
        num_neigh = mask.sum(1, keepdim=True).clamp(min=1)
        mask = mask.div(num_neigh)

        # Mean aggregation: (batch, num_unique) @ (num_unique, dim) -> (batch, dim).
        embed_matrix = self.features(node_ids)
        return mask.mm(embed_matrix)

In [3]:
"""sample num_sample neighbors"""
# if num_neigh >= num_sample, sample num_sample neighbors
# else, keep all neighbors
to_neighs = [set([1]), set([2, 3]), set([1, 3]), set([1, 2])]
num_sample = 1
# random.sample rejects sets since Python 3.11, so sample from a list copy
samp_neighs = [set(random.sample(list(to_neigh), num_sample))
               if len(to_neigh) >= num_sample else to_neigh for to_neigh in to_neighs]
print(samp_neighs)

[{1}, {2}, {3}, {1}]


In [4]:
"""gcn: self-loop"""
gcn = False
# node ids of this toy batch: the node at position i owns samp_neighs[i]
nodes = list(range(len(samp_neighs)))
if gcn:
    # GCN-style self-loop: put each node into its own neighborhood via set union
    samp_neighs = [samp_neigh | {node} for node, samp_neigh in zip(nodes, samp_neighs)]
print(samp_neighs)

[{1}, {2}, {3}, {1}]


In [5]:
# every distinct node that was sampled as someone's neighbor
unique_nodes_list = list(set.union(*samp_neighs))
print(unique_nodes_list)
# node id -> mask column (= row of the gathered feature matrix)
unique_nodes = dict(zip(unique_nodes_list, range(len(unique_nodes_list))))
# one mask row per batch node, one column per distinct sampled neighbor
mask = torch.zeros(len(samp_neighs), len(unique_nodes_list))
# column of each sampled neighbor, walking the batch rows in order
column_indices = [unique_nodes[node] for samp_neigh in samp_neighs for node in samp_neigh]
print(column_indices)
# the matching row: batch position `row` repeated once per sampled neighbor
row_indices = [row for row, samp_neigh in enumerate(samp_neighs) for _ in samp_neigh]
print(row_indices)
# scatter the 1s -> adjacency matrix restricted to the sampled neighbors
mask[row_indices, column_indices] = 1
print(mask)

[1, 2, 3]
[0, 1, 2, 0]
[0, 1, 2, 3]
tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.],
        [1., 0., 0.]])


In [6]:
# row-normalize so each row averages that node's sampled neighbors;
# clamp(min=1) keeps an isolated node's all-zero row at 0 instead of 0/0 = NaN
num_neigh = mask.sum(1, keepdim=True).clamp(min=1)
mask = mask.div(num_neigh)
print(mask)

tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.],
        [1., 0., 0.]])


In [7]:
# toy feature table: row k holds node k's single feature, valued k + 1
features = torch.tensor([[1.], [2.], [3.], [4.]])
print(features)
# keep only the sampled neighbors' rows, ordered like the mask's columns
embed_features = features[torch.tensor(unique_nodes_list)]
print(embed_features)
# row-normalized mask times neighbor features = per-node mean over sampled neighbors
to_feats = mask @ embed_features
print(to_feats)

tensor([[1.],
        [2.],
        [3.],
        [4.]])
tensor([[2.],
        [3.],
        [4.]])
tensor([[2.],
        [3.],
        [4.],
        [2.]])


In [8]:
import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F

In [9]:
class Encoder(nn.Module):
    """
    Encodes a batch of nodes using the 'convolutional' GraphSAGE approach:
    aggregate neighbor features, optionally concatenate the node's own features,
    then apply one linear layer followed by ReLU.
    """
    def __init__(self, features, feature_dim,
            embed_dim, adj_lists, aggregator,
            num_sample=10,
            base_model=None, gcn=False, cuda=False,
            feature_transform=False):
        """
        features -- callable mapping a LongTensor of node ids to their feature rows
        feature_dim -- width of the rows returned by `features`
        embed_dim -- size of the produced embeddings
        adj_lists -- indexable by int node id, giving that node's neighbor collection
        aggregator -- module called as aggregator(nodes, neighbor_lists, num_sample)
        num_sample -- neighbors sampled per node, forwarded to the aggregator
        base_model -- optional previous layer; when it is an nn.Module, storing it
                      registers it as a submodule so its parameters are tracked
        gcn -- if True use only the aggregated neighborhood (weight input width is
               feature_dim); otherwise concatenate self and neighbor features
               (input width 2 * feature_dim)
        cuda -- whether to move index tensors to GPU; also pushed onto the aggregator
        feature_transform -- accepted for compatibility; currently unused
        """
        super(Encoder, self).__init__()

        self.features = features  # all features
        self.feat_dim = feature_dim
        self.adj_lists = adj_lists  # all adj lists
        self.aggregator = aggregator
        self.num_sample = num_sample
        if base_model is not None:
            self.base_model = base_model

        self.gcn = gcn
        self.embed_dim = embed_dim
        self.cuda = cuda
        # keep the aggregator's device flag in sync with this encoder
        self.aggregator.cuda = cuda
        in_dim = self.feat_dim if self.gcn else 2 * self.feat_dim
        self.weight = nn.Parameter(torch.FloatTensor(embed_dim, in_dim))
        init.xavier_uniform_(self.weight)

    def forward(self, nodes):
        """
        Generates embeddings for a batch of nodes.

        nodes -- list of node ids

        Returns a tensor of shape (embed_dim, len(nodes)): one COLUMN per node.
        """
        neigh_lists = [self.adj_lists[int(node)] for node in nodes]
        # call the module (not .forward) so any registered hooks run
        neigh_feats = self.aggregator(nodes, neigh_lists, self.num_sample)
        if self.gcn:
            combined = neigh_feats
        else:
            node_ids = torch.LongTensor(nodes)
            if self.cuda:
                node_ids = node_ids.cuda()
            self_feats = self.features(node_ids)
            combined = torch.cat([self_feats, neigh_feats], dim=1)
        # weight is (embed_dim, in_dim); combined.t() is (in_dim, batch) -> (embed_dim, batch)
        return F.relu(self.weight.mm(combined.t()))

In [10]:
"""graphsage model"""
class SupervisedGraphSage(nn.Module):
    """
    Supervised GraphSAGE: a linear classifier on top of an encoder's node embeddings.

    num_classes -- number of output classes
    enc -- encoder module exposing `embed_dim`; expected to return embeddings of
           shape (embed_dim, batch), one column per node
    """

    def __init__(self, num_classes, enc):
        super(SupervisedGraphSage, self).__init__()
        self.enc = enc
        self.xent = nn.CrossEntropyLoss()

        # (num_classes, embed_dim) so that weight.mm(embeds) is (num_classes, batch)
        self.weight = nn.Parameter(torch.FloatTensor(num_classes, enc.embed_dim))
        init.xavier_uniform_(self.weight)

    def forward(self, nodes):
        """Return raw class scores (logits) of shape (len(nodes), num_classes)."""
        embeds = self.enc(nodes)
        scores = self.weight.mm(embeds)
        return scores.t()

    def loss(self, nodes, labels):
        """
        Cross-entropy loss of the batch's scores against `labels`.

        labels -- LongTensor of class indices, shape (batch,) or (batch, 1). It is
                  flattened rather than squeezed so a batch of one still works.
        """
        scores = self.forward(nodes)
        return self.xent(scores, labels.reshape(-1))

In [11]:
"""model"""
# random data -- sizes presumably mirror the Cora citation graph
# (2708 nodes, 1433 features, 7 classes); confirm before relying on them
NUM_NODES, NUM_FEATS, NUM_CLASSES = 2708, 1433, 7
EMBED_DIM = 128          # hidden size of both encoder layers
NUM_RANDOM_NEIGHS = 15   # neighbors drawn per node for the random graph
SEED = 42
random.seed(SEED)        # drives the random adjacency lists and neighbor sampling
torch.manual_seed(SEED)  # drives feat_data and the xavier weight init
feat_data = torch.rand([NUM_NODES, NUM_FEATS])
adj_lists = [set(random.sample(range(NUM_NODES), NUM_RANDOM_NEIGHS)) for _ in range(NUM_NODES)]
# input: frozen lookup table mapping node ids to their feature rows
features = nn.Embedding.from_pretrained(feat_data, freeze=True)
# layer 1
agg1 = MeanAggregator(features)
enc1 = Encoder(features, NUM_FEATS, EMBED_DIM, adj_lists, agg1, gcn=True)
# layer 2: enc1 returns (embed_dim, batch), so transpose to one row per node
def enc1_features(nodes):
    return enc1(nodes).t()
agg2 = MeanAggregator(enc1_features)
enc2 = Encoder(enc1_features, enc1.embed_dim, EMBED_DIM, adj_lists, agg2,
        base_model=enc1, gcn=True)
# NOTE(review): both Encoders use gcn=True (no self-feature concatenation) while both
# MeanAggregators keep their default gcn=False (no self-loop), so a node's own
# features never enter its embedding -- confirm this is intended.
# model
graphsage = SupervisedGraphSage(NUM_CLASSES, enc2)