In [1]:
import dgl
import dgl.nn as gnn

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

Using backend: pytorch


# Edge classification with neighborhood sampling

Обучение для задачи предсказания на ребрах похоже на предсказание на узлах. 
1. Для сэмплинга соседей используются все те же самые сэмплеры
2. Вместо `NodeDataLoader` используем `EdgeDataLoader`, который итерируется по батчам из ребер. Он возвращает:
* `input_nodes` - узлы, которые необходимы для расчетов по батчу
* `edge_subgraph` - подграф на основе ребер из батча
* `blocks` - MFGs для вычислений по слоям
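3. Иногда требуется удалить из графа вычислений те ребра, на которых происходит обучение, иначе модель в теории может использовать сам факт наличия ребра (утечка целевой информации). В `DGL` можно удалить ребра, попавшие в батч, из оригинального графа перед сэмплингом соседей (а при необходимости — и обратные им ребра) с помощью параметра `exclude` у `EdgeDataLoader`. __Вопрос__: убирает ли он что-то по умолчанию? __Ответ__: нет — по умолчанию `exclude=None`, и никакие ребра не исключаются; исключение нужно включать явно (например, `exclude='reverse_id'` вместе с `reverse_eids`).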

Как обычно, модель для задачи классификации ребер состоит из двух частей:
1. Получение представлений для узлов
2. Расчет оценки для ребра на основе представлений инцидентных узлов

In [2]:
# модель из 10_dgl_stochastic_node_classification
class Conv(nn.Module):
    def __init__(self, in_features, hidden_features):
        super().__init__()
        self.conv1 = gnn.GraphConv(in_features, hidden_features)
        self.conv2 = gnn.GraphConv(hidden_features, hidden_features)

    def forward(self, blocks, x):
        x = F.relu(self.conv1(blocks[0], x))
        x = self.conv2(blocks[1], x)
        return x  

# обычный MLPPredictor, в forward вместо всего графа
# придет подграф, созданный на основе батча ребер
class MLPPredictor(nn.Module):
    def __init__(self, n_inputs, n_classes):
        super().__init__()
        self.linear = nn.Linear(2 * n_inputs, n_classes)

    def get_score(self, edges):
        data = torch.cat([edges.src['h'], edges.dst['h']], dim=1)
        return {'score': self.linear(data)}

    def forward(self, edge_subgraph, features):
        with edge_subgraph.local_scope():
            edge_subgraph.ndata['h'] = features
            edge_subgraph.apply_edges(self.get_score)
            return edge_subgraph.edata['score']  

In [3]:
class GCN(nn.Module):
    def __init__(self, n_inputs, n_hidden, n_classes):
        super().__init__()
        self.conv = Conv(n_inputs, n_hidden)
        self.predictor = MLPPredictor(n_hidden, n_classes)

    def forward(self, edge_subgraph, blocks, features):
        # blocks - для расчета по всем нужным узлам
        # edge_subgraph - для расчета по всем нужным связям
        out = self.conv(blocks, features)
        return self.predictor(edge_subgraph, out)

In [4]:
from utils import create_edge_pred_graph

n_nodes, n_edges, n_node_features,n_edge_features = 50, 100, 10, 10
G = create_edge_pred_graph(n_nodes, n_edges, n_node_features, n_edge_features)

node_features = G.ndata['feature']
edge_labels = G.edata['label_class'].long()
train_mask = G.edata['train_mask']
train_eids = train_mask.nonzero().flatten()

n_classes = len(edge_labels.unique())

In [5]:
model = GCN(n_node_features, 16, n_classes)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=.001)

sampler = dgl.dataloading.MultiLayerFullNeighborSampler(n_layers=2)
dataloader = dgl.dataloading.EdgeDataLoader(G, train_eids, sampler, batch_size=32, shuffle=True)

for epoch in range(20):
    for step, (input_nodes, edge_subgraph, blocks) in enumerate(dataloader):
        batch_features = blocks[0].srcdata['feature']
        # batch_labels = edge_labels[edge_subgraph.edata['_ID']]
        batch_labels = edge_subgraph.edata['label_class'].long()

        preds = model(edge_subgraph, blocks, batch_features)
        loss = criterion(preds, batch_labels)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if not step % 5:
            acc = (preds.argmax(dim=1) == batch_labels).sum() / len(preds)
            print('Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f}'.format(
                        epoch, step, loss.item(), acc.item()))


Epoch 00000 | Step 00000 | Loss 0.6714 | Train Acc 0.5625
Epoch 00001 | Step 00000 | Loss 0.6803 | Train Acc 0.5000
Epoch 00002 | Step 00000 | Loss 0.6657 | Train Acc 0.5938
Epoch 00003 | Step 00000 | Loss 0.6818 | Train Acc 0.5312
Epoch 00004 | Step 00000 | Loss 0.6566 | Train Acc 0.6250
Epoch 00005 | Step 00000 | Loss 0.6461 | Train Acc 0.6875
Epoch 00006 | Step 00000 | Loss 0.6125 | Train Acc 0.7188
Epoch 00007 | Step 00000 | Loss 0.6389 | Train Acc 0.6875
Epoch 00008 | Step 00000 | Loss 0.6505 | Train Acc 0.6250
Epoch 00009 | Step 00000 | Loss 0.6558 | Train Acc 0.7188
Epoch 00010 | Step 00000 | Loss 0.6794 | Train Acc 0.5938
Epoch 00011 | Step 00000 | Loss 0.6458 | Train Acc 0.7500
Epoch 00012 | Step 00000 | Loss 0.6351 | Train Acc 0.6875
Epoch 00013 | Step 00000 | Loss 0.6460 | Train Acc 0.6562
Epoch 00014 | Step 00000 | Loss 0.6219 | Train Acc 0.7500
Epoch 00015 | Step 00000 | Loss 0.6211 | Train Acc 0.7188
Epoch 00016 | Step 00000 | Loss 0.6231 | Train Acc 0.7500
Epoch 00017 | 

## Heterogeneous graph stochastic training

In [6]:
# копия из 10_dgl_stochastic_node_classification
class RGCN(nn.Module):
    def __init__(self, n_inputs, n_hidden, rel_names):
        super().__init__()
        conv1_modules = {rel: gnn.GraphConv(n_inputs, n_hidden) for rel in rel_names}
        conv2_modules = {rel: gnn.GraphConv(n_hidden, n_hidden) for rel in rel_names}
        self.conv1 = gnn.HeteroGraphConv(conv1_modules, aggregate='sum')
        self.conv2 = gnn.HeteroGraphConv(conv2_modules, aggregate='sum')

    def forward(self, blocks, features):
        out = self.conv1(blocks[0], features)
        out = {k: F.relu(v) for k, v in out.items()}
        out = self.conv2(blocks[1], out)
        return out

class ScorePredictor(nn.Module):
    def __init__(self, n_inputs, n_classes):
        super().__init__()
        self.W = nn.Linear(2 * n_inputs, n_classes)

    def apply_edges(self, edges):
        data = torch.cat([edges.src['h'], edges.dst['h']], dim=1)
        return {'score': self.W(data)}

    def forward(self, edge_subgraph, features):
        with edge_subgraph.local_scope():
            edge_subgraph.ndata['h'] = features
            # итерируемся по всем типам ребер, чтобы получить для всех них оценки
            for etype in edge_subgraph.canonical_etypes:
                edge_subgraph.apply_edges(self.apply_edges, etype=etype)
            return edge_subgraph.edata['score']

In [14]:
class GCN(nn.Module):
    def __init__(self, n_inputs, n_hidden, n_classes, etypes):
        super().__init__()
        self.rgcn = RGCN(n_inputs, n_hidden, etypes)
        self.pred = ScorePredictor(n_hidden, n_classes)

    def forward(self, edge_subgraph, blocks, x):
        x = self.rgcn(blocks, x)
        return self.pred(edge_subgraph, x)

In [37]:
from utils import create_heterograph

G = create_heterograph()

n_hetero_features = 10
rel_names = G.etypes

train_mask = G.edges['click'].data['train_mask']
train_eids = {'click': G.edges['click'].data['train_mask'].nonzero().flatten()}

n_classes = 4

In [40]:
model = GCN(n_hetero_features, 16, n_classes, rel_names)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=.001)

n_edges = G.num_edges()
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(n_layers=2)
dataloader = dgl.dataloading.EdgeDataLoader(G, train_eids, sampler, batch_size=128, shuffle=True)

for epoch in range(20):
    for step, (input_nodes, edge_subgraph, blocks) in enumerate(dataloader):
        batch_features = blocks[0].srcdata['feature']
        # синтетические классы
        batch_labels = (edge_subgraph.edges['click'].data['label'] % 4).long()

        preds = model(edge_subgraph, blocks, batch_features)
        preds_click = preds[('user', 'click', 'item')]
        loss = criterion(preds_click, batch_labels)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if not step % 5:
            acc = (preds_click.argmax(dim=1) == batch_labels).sum() / len(preds_click)
            print('Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f}'.format(
                        epoch, step, loss.item(), acc.item()))


Epoch 00000 | Step 00000 | Loss 1.5354 | Train Acc 0.2109
Epoch 00000 | Step 00005 | Loss 1.4470 | Train Acc 0.2500
Epoch 00000 | Step 00010 | Loss 1.4314 | Train Acc 0.2969
Epoch 00000 | Step 00015 | Loss 1.3962 | Train Acc 0.3125
Epoch 00000 | Step 00020 | Loss 1.4737 | Train Acc 0.2109
Epoch 00001 | Step 00000 | Loss 1.3788 | Train Acc 0.3047
Epoch 00001 | Step 00005 | Loss 1.3628 | Train Acc 0.3125
Epoch 00001 | Step 00010 | Loss 1.3986 | Train Acc 0.2422
Epoch 00001 | Step 00015 | Loss 1.4299 | Train Acc 0.2422
Epoch 00001 | Step 00020 | Loss 1.4191 | Train Acc 0.2734
Epoch 00002 | Step 00000 | Loss 1.3411 | Train Acc 0.3828
Epoch 00002 | Step 00005 | Loss 1.3798 | Train Acc 0.2891
Epoch 00002 | Step 00010 | Loss 1.3671 | Train Acc 0.3594
Epoch 00002 | Step 00015 | Loss 1.4395 | Train Acc 0.2344
Epoch 00002 | Step 00020 | Loss 1.3408 | Train Acc 0.3906
Epoch 00003 | Step 00000 | Loss 1.3552 | Train Acc 0.3203
Epoch 00003 | Step 00005 | Loss 1.4390 | Train Acc 0.2344
Epoch 00003 | 