In [None]:
import dgl.function as fn
import torch.nn as nn
import torch
import torch.nn.functional as F


class HeteroLoss(nn.Module):
    """
    Base class for link-prediction losses on a heterogeneous graph.

    Subclasses implement ``apply_edges`` (an edge function returning a dict with a
    ``'score'`` entry) and turn the ``(score, label)`` pair produced by ``forward``
    into an actual loss value.

    Args:
        train_on_gc: if True, only edges of the canonical type
            ``('okved', 'gc', 'okved')`` contribute scores; otherwise the scores of
            every canonical edge type are concatenated.
    """

    def __init__(self, train_on_gc=True):
        # Initialise nn.Module before assigning attributes on the instance.
        super().__init__()
        self.train_on_gc = train_on_gc

    def apply_edges(self, edges):
        """Edge function producing ``{'score': tensor}``; must be overridden by subclasses."""
        raise NotImplementedError

    def _score_edges(self, graph, node_features):
        """Apply ``self.apply_edges`` to every canonical edge type of ``graph`` and collect the scores."""
        scores = {}
        with graph.local_scope():
            graph.ndata['h'] = node_features
            for etype in graph.canonical_etypes:
                graph.apply_edges(self.apply_edges, etype=etype)
                scores[etype] = graph.edges[etype].data['score']
        return scores

    def forward(self, block_outputs, pos_graph, neg_graph):
        """
        Score the edges of the positive and the negative graph.

        Args:
            block_outputs: mapping that holds the node representations under the key ``'okved'``.
            pos_graph: graph whose edges are the positive examples.
            neg_graph: graph whose edges are the negative examples.

        Returns:
            ``(score, label)``: the positive scores followed by the negative scores,
            and an int64 label tensor (1 for positive, 0 for negative edges) created
            on the same device as ``score``.
        """
        node_features = block_outputs['okved']
        pos_scores = self._score_edges(pos_graph, node_features)
        neg_scores = self._score_edges(neg_graph, node_features)

        if self.train_on_gc:
            gc_etype = ('okved', 'gc', 'okved')
            pos_score = pos_scores[gc_etype]
            neg_score = neg_scores[gc_etype]
        else:
            pos_score = torch.cat(list(pos_scores.values()))
            neg_score = torch.cat(list(neg_scores.values()))

        score = torch.cat([pos_score, neg_score], dim=0)
        # Build the labels directly on the scores' device so a loss computed from
        # (score, label) does not hit a CPU/accelerator device mismatch.
        label = torch.cat([
            torch.ones(len(pos_score), dtype=torch.long, device=score.device),
            torch.zeros(len(neg_score), dtype=torch.long, device=score.device),
        ])

        return score, label

In [None]:
class HeteroDMCELoss(HeteroLoss):
    """
    Link-prediction loss for a heterograph that scores an edge as a relation-weighted
    dot product of its endpoint embeddings, ``sum(h_u * r * h_v)``, and optimises
    binary cross-entropy on those logits.

    Args:
        emb_size: length of the learned relation vector ``r``; it must match the
            size of the node embeddings it is multiplied with.
        train_on_gc: forwarded to ``HeteroLoss``.
        pos_weight: weight of the positive class in the binary cross-entropy.
            Defaults to 5, the value that used to be hard-coded.
    """
    # NOTE(review): presumably tells callers that one number is produced per edge — confirm with the consumer.
    output_number = True

    def __init__(self, emb_size, train_on_gc=True, pos_weight=5.0):
        super().__init__(train_on_gc)
        self.emb_size = emb_size
        self.pos_weight = float(pos_weight)
        self.relation = nn.Parameter(torch.rand(emb_size))

    def apply_edges(self, edges):
        """Return ``{'score': (E, 1) tensor}`` with ``sum(h_u * relation * h_v)`` for every edge."""
        h_u = edges.src['h']
        h_v = edges.dst['h']
        # The relation vector broadcasts over the edge dimension; no explicit repeat is needed.
        score = torch.sum(h_u * self.relation * h_v, dim=1).view(-1, 1)
        return {'score': score}

    def forward(self, block_outputs, pos_graph, neg_graph):
        """
        Compute the weighted binary cross-entropy over positive and negative edges.

        Returns:
            ``(loss, score, label)`` where ``score`` has shape ``(E, 1)`` and
            ``label`` holds 1 for positive and 0 for negative edges.
        """
        score, label = super().forward(block_outputs, pos_graph, neg_graph)
        assert score.shape[1] == 1
        # Keep the targets and the class weight on the same device/dtype as the logits.
        label = label.to(score.device)
        pos_weight = score.new_tensor([self.pos_weight])
        loss = F.binary_cross_entropy_with_logits(score.flatten(), label.float(), pos_weight=pos_weight)

        return loss, score, label

In [None]:
class HeteroMLPCELoss(HeteroLoss):
    """
    Link-prediction loss for a heterograph that scores an edge with a fully connected
    layer applied to the concatenated endpoint embeddings and optimises cross-entropy.

    Args:
        in_features: size of a single node embedding (the linear layer sees twice that).
        out_classes: number of output classes of the linear layer.
        node_type: kept for interface compatibility and stored on the instance; the
            scoring itself does not read it.
        train_on_gc: forwarded to ``HeteroLoss``.
    """
    # NOTE(review): presumably tells callers that the score is a vector of class logits — confirm with the consumer.
    output_number = False

    def __init__(self, in_features, out_classes, node_type='okved', train_on_gc=True):
        # The class must derive from HeteroLoss: nn.Module.__init__ accepts no
        # ``train_on_gc`` argument and nn.Module provides no usable ``forward``.
        super().__init__(train_on_gc)
        self.node_type = node_type
        self.W = nn.Linear(in_features * 2, out_classes)

    def apply_edges(self, edges):
        """Return ``{'score': (E, out_classes) tensor}`` from the concatenated endpoint embeddings."""
        h_u = edges.src['h']
        h_v = edges.dst['h']
        score = self.W(torch.cat([h_u, h_v], 1))
        return {'score': score}

    def forward(self, block_outputs, pos_graph, neg_graph):
        """
        Compute the cross-entropy over positive (class 1) and negative (class 0) edges.

        Returns:
            ``(loss, score, label)``.
        """
        score, label = super().forward(block_outputs, pos_graph, neg_graph)
        # Targets must live on the same device as the logits.
        label = label.to(score.device)
        loss = F.cross_entropy(score, label)
        return loss, score, label

In [None]:
class DotCeLossWith0kvedDistances(torch.nn.Module):
    """
    Binary edge-classification loss that takes the nodes' OKVED code embeddings into account.

    The logit of an edge ``(u, v)`` is
    ``dot(h_u, h_v) - okved_impact * ||E(o_u) - E(o_v)||_2`` where ``E`` is
    ``okved_embeddings`` and ``o_u``/``o_v`` are the endpoints' OKVED indices.

    NOTE: the class name contains the digit ``0`` ("0kved"); it is kept unchanged so
    existing callers keep working.

    Args:
        okved_embeddings: embedding table looked up with the nodes' OKVED indices.
        pos_weight: passed to ``binary_cross_entropy_with_logits`` as ``pos_weight``.
        okved_impact: multiplier of the embedding distance subtracted from the dot product.
    """

    def __init__(self, okved_embeddings: torch.nn.Embedding, pos_weight, okved_impact: float = 1.0):
        super().__init__()
        self.okved_embeddings = okved_embeddings
        self.okved_impact = okved_impact
        self.pos_weight = pos_weight

    def apply_edges(self, edges):
        """Return ``{'score': (E, 1) tensor}`` with the distance-penalised dot product of every edge."""
        h_u = edges.src['h']
        h_v = edges.dst['h']
        emb_u = self.okved_embeddings(edges.src['okved'])
        emb_v = self.okved_embeddings(edges.dst['okved'])
        # vector_norm has a well-defined (zero) gradient at exactly zero distance, which
        # occurs whenever both endpoints share an OKVED code; pow(2).sum().sqrt() yields
        # NaN gradients there.
        d_o = torch.linalg.vector_norm(emb_u - emb_v, ord=2, dim=1)
        dot = torch.sum(h_u * h_v, dim=1)
        score = dot.subtract(self.okved_impact * d_o).view(-1, 1)
        return {'score': score}

    def forward(self, g, h, okveds, labels, mask):
        """
        Compute the weighted binary cross-entropy on the edges selected by ``mask``.

        Args:
            g: graph whose edges are scored.
            h: node representations, stored as node feature ``'h'``.
            okveds: per-node OKVED indices, stored as node feature ``'okved'``.
            labels: per-edge targets; converted to float and indexed with ``mask``.
            mask: index or boolean mask applied to the flattened scores and the labels.

        Returns:
            ``(loss, score_masked, labels_masked)``.
        """
        # Features written inside local_scope are discarded on exit, so no manual cleanup is needed.
        with g.local_scope():
            g.ndata['h'] = h
            g.ndata['okved'] = okveds
            g.apply_edges(self.apply_edges)
            score = g.edata['score']

        assert score.shape[1] == 1
        score_masked = score.flatten()[mask]
        labels_masked = labels.float()[mask]

        pos_weight = self.pos_weight
        if torch.is_tensor(pos_weight):
            # Align the class weight with the logits' device and dtype.
            pos_weight = pos_weight.to(score_masked)
        loss = F.binary_cross_entropy_with_logits(score_masked, labels_masked, pos_weight=pos_weight)

        return loss, score_masked, labels_masked