# In [None]:
import torch.nn as nn
import torch
import dgl.nn as gnn


class SAGE(nn.Module):
    """Multi-layer GraphSAGE network with OKVED embeddings appended to node inputs.

    Every forward pass looks up an embedding for each OKVED index, concatenates
    it to the remaining node features ``x`` and runs the result through a stack
    of ``dgl.nn.SAGEConv`` layers.

    Args:
        in_feats: Input width of the first layer. Because ``x`` and the OKVED
            embeddings are concatenated before the first layer, this must equal
            the width of ``x`` plus ``embedding_tensor.shape[1]``.
        n_hidden: Width of the hidden layers.
        n_classes: Output width of the last layer.
        n_layers: Number of SAGEConv layers. Any value ``<= 1`` builds a single
            ``in_feats -> n_classes`` layer.
        activation: Activation passed to every layer except the last.
        dropout: Feature-dropout rate (``feat_drop``) passed to every layer
            except the last.
        freeze: If True, the embedding table is not updated during training.
        embedding_tensor: 2-D tensor of pretrained embedding weights, one row
            per OKVED index.
        aggregator_type: Neighbour aggregator used by every SAGEConv layer.
            Defaults to ``'mean'``, the previously hard-coded choice.
    """

    def __init__(self, in_feats, n_hidden, n_classes, n_layers,
                 activation, dropout,
                 freeze, embedding_tensor, aggregator_type='mean'):
        super().__init__()
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        # Registered before `embeddings` on purpose: this keeps the order of
        # parameters()/state_dict() entries identical to earlier versions.
        self.layers = nn.ModuleList()

        self.freeze = freeze
        self.embeddings = nn.Embedding.from_pretrained(embedding_tensor, freeze=self.freeze)

        if n_layers > 1:
            # Input layer plus (n_layers - 2) hidden layers, all with
            # activation and feature dropout.
            dims = [in_feats] + [n_hidden] * (n_layers - 1)
            for d_in, d_out in zip(dims[:-1], dims[1:]):
                self.layers.append(
                    gnn.SAGEConv(d_in, d_out, aggregator_type=aggregator_type,
                                 activation=activation, feat_drop=dropout))
            # The output layer gets no activation and no feature dropout.
            self.layers.append(
                gnn.SAGEConv(n_hidden, n_classes, aggregator_type=aggregator_type))
        else:
            self.layers.append(
                gnn.SAGEConv(in_feats, n_classes, aggregator_type=aggregator_type))

        # Kept as attributes for backward compatibility; forward() does not use
        # them because dropout and activation are applied inside the SAGEConv
        # layers through `feat_drop` / `activation`.
        self.dropout = nn.Dropout(dropout)
        self.activation = activation

    def forward(self, g, x, okveds):
        """Compute output node representations.

        Args:
            g: Graph passed to every SAGEConv layer.
            x: Node feature tensor, concatenated with the OKVED embeddings along
                dim 1.
            okveds: Integer tensor of OKVED indices into the embedding table.
                Presumably one index per node (shape ``(N,)``) so the lookup is
                2-D and can be concatenated with ``x`` -- TODO confirm.

        Returns:
            Output of the last SAGEConv layer (width ``n_classes``).
        """
        # Look up the OKVED embeddings, append them to the other node features
        # and run the result through the network.
        okved_embs = self.embeddings(okveds)
        h = torch.cat([x, okved_embs], dim=1)
        for layer in self.layers:
            h = layer(g, h)
        return h