In [None]:
import dgl
import torch.nn as nn
import dgl.nn as gnn
import torch


class RGCN(nn.Module):
    """Stack of heterogeneous graph convolutions, one ``GraphConv`` per relation.

    Every layer is a ``gnn.HeteroGraphConv`` holding one
    ``gnn.GraphConv(..., norm='both')`` for each name in ``rel_names`` and
    combining the per-relation results with ``aggregate='sum'``.

    Args:
        in_feats: input feature size of the first layer.
        n_hidden: output size of every layer except the last.
        n_classes: output size of the last layer.
        n_layers: number of layers; any value not greater than 1 yields a
            single ``in_feats -> n_classes`` layer.
        activation: callable applied to each tensor of a layer's output dict
            (all layers except the last).
        dropout: value passed to ``nn.Dropout``.
        rel_names: iterable of relation names used as the keys of each
            ``HeteroGraphConv``.
    """

    def __init__(self, in_feats, n_hidden, n_classes, n_layers, activation, dropout, rel_names):
        super().__init__()
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.n_classes = n_classes

        def build_layer(src_dim, dst_dim):
            # One GraphConv per relation, summed across relations.
            per_relation = {}
            for rel in rel_names:
                per_relation[rel] = gnn.GraphConv(src_dim, dst_dim, norm='both')
            return gnn.HeteroGraphConv(per_relation, aggregate='sum')

        # Consecutive pairs of `widths` are the (input, output) sizes of each layer.
        if n_layers > 1:
            widths = [in_feats] + [n_hidden] * (n_layers - 1) + [n_classes]
        else:
            widths = [in_feats, n_classes]

        self.layers = nn.ModuleList()
        for src_dim, dst_dim in zip(widths[:-1], widths[1:]):
            self.layers.append(build_layer(src_dim, dst_dim))

        self.dropout = nn.Dropout(dropout)
        self.activation = activation

    def forward(self, blocks, x):
        """Run the layers over ``blocks`` pairwise and return the final feature dict.

        Args:
            blocks: iterable of graph blocks, consumed in lockstep with
                ``self.layers`` (iteration stops at the shorter of the two).
            x: dict of input features passed to the first layer.

        Returns:
            The dict produced by the last layer that was applied; activation
            and dropout are skipped only for the layer whose index is
            ``len(self.layers) - 1``.
        """
        final_idx = len(self.layers) - 1
        feats = x
        for idx, (conv, blk) in enumerate(zip(self.layers, blocks)):
            feats = conv(blk, feats)
            if idx == final_idx:
                continue
            # Two separate passes: activate every tensor first, then apply dropout.
            activated = {name: self.activation(t) for name, t in feats.items()}
            feats = {name: self.dropout(t) for name, t in activated.items()}
        return feats

In [None]:
def get_embeddings(self, g):
    """
    Inference without neighbor sampling.

    Applies the model layer by layer to all 'okved' nodes of ``g``. Each layer
    is run through a one-layer ``MultiLayerFullNeighborSampler`` and its
    complete output becomes the input of the next layer.

    Args:
        self: model providing ``layers``, ``n_hidden``, ``n_classes``,
            ``activation`` and ``dropout``.
        g: DGL graph with node type 'okved' whose input features are stored
            in ``g.ndata['features']``.

    Returns:
        dict with the single key 'okved' holding the output of the last layer,
        one row per 'okved' node (``n_classes`` columns). The buffer is
        created with ``torch.zeros`` and filled from ``.cpu()`` copies.

    Note:
        ``self.dropout`` is applied between layers exactly as in ``forward``;
        presumably the caller switches the model to eval mode first - verify
        at the call site.
    """
    x = {'okved': g.ndata['features'].float()}
    last_layer = len(self.layers) - 1
    for l, layer in enumerate(self.layers):
        # Hidden width for intermediate layers, class count for the final one.
        out_dim = self.n_hidden if l != last_layer else self.n_classes
        y = {'okved': torch.zeros(g.number_of_nodes('okved'), out_dim)}

        sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
        dataloader = dgl.dataloading.NodeDataLoader(g,
                                                    {'okved': torch.arange(g.num_nodes()).to(g.device)},
                                                    sampler,
                                                    batch_size=g.num_nodes(),
                                                    shuffle=False,
                                                    drop_last=False,
                                                    num_workers=1)

        for input_nodes, output_nodes, blocks in dataloader:
            block = blocks[0].int().to(g.device)
            h = {'okved': x['okved'][input_nodes].to(g.device)}
            h = layer(block, h)
            if l != last_layer:
                h = {k: self.activation(v) for k, v in h.items()}
                h = {k: self.dropout(v) for k, v in h.items()}

            y['okved'][output_nodes] = h['okved'].cpu()

        # Hand the finished layer output to the next layer only after every
        # batch has been written, so no batch reads a partially filled buffer.
        x = y

    return y