In [112]:
import torch as th
import numpy as np


class Euclidean:
    """Flat (Euclidean) geometry used by `Model` to measure distances between embeddings.

    Args:
        max_norm: stored on the instance; not read by any method of this class.
        K: optional constant; when given, ``inner_radius`` is derived from it.
        **kwargs: accepted and ignored, so callers can pass extra manifold options.
    """

    def __init__(self, max_norm=None, K=None, **kwargs):
        self.max_norm = max_norm
        self.K = K
        if K is not None:
            # NOTE(review): only stored here; presumably consumed by an
            # entailment-cone style objective elsewhere — confirm.
            self.inner_radius = 2 * self.K / (1 + np.sqrt(1 + 4 * self.K * self.K))

    def init_weights(self, w, scale=1e-4):
        """Fill ``w.weight`` in place with samples from U(-scale, scale).

        Args:
            w: a module with a ``weight`` tensor (e.g. ``torch.nn.Embedding``).
            scale: half-width of the uniform interval.
        """
        # no_grad() is the supported way to mutate a parameter without autograd
        # recording it (preferred over touching ``.data``).
        with th.no_grad():
            w.weight.uniform_(-scale, scale)

    def distance(self, u, v):
        """Euclidean (L2) distance between ``u`` and ``v`` along the last dimension.

        ``u`` and ``v`` must be broadcast-compatible; the last dimension is reduced.
        Note: taking the square root per coordinate *before* summing would give
        the L1 (Manhattan) distance instead; the norm is computed over the whole
        difference vector here. Its gradient is defined (zero) when ``u == v``.
        """
        return th.linalg.vector_norm(u - v, ord=2, dim=-1)

from torch.nn import Module
from torch.nn import Embedding
import torch.nn.functional as fun
from torch import zeros


class Model(Module):
    """Embedding model that scores the first id of each row against the other ids.

    Args:
        manifold: object exposing ``distance(u, v)`` that reduces the last dimension.
        n: number of rows in the embedding table (valid ids are 0..n-1).
        dim: embedding dimensionality.
        sparse: forwarded to ``torch.nn.Embedding`` (sparse gradients).
    """

    def __init__(self, manifold, n, dim, sparse=False):
        super().__init__()
        self.manifold = manifold
        self.n = n
        self.dim = dim
        self.model = Embedding(n, dim, sparse=sparse)

    def forward(self, inputs):
        """Distances from column 0 of each row to every other column of that row.

        Args:
            inputs: integer id tensor, typically of shape (batch, k) with k >= 2,
                e.g. (10, 52).

        Returns:
            Tensor of shape (batch, k - 1); entry [b, j] is the manifold distance
            between the embeddings of inputs[b, 0] and inputs[b, j + 1]. A trailing
            dimension of size 1 is squeezed, so k == 2 yields shape (batch,).
        """
        e = self.model(inputs)  # (batch, k, dim): embedding lookup
        o = e.narrow(1, 1, e.size(1) - 1)  # (batch, k-1, dim): all columns except the first
        s = e.narrow(1, 0, 1).expand_as(o)  # (batch, k-1, dim): first column repeated
        dist = self.manifold.distance(s, o)
        return dist.squeeze(-1)

    def loss(self, inp, target=None):
        """Softmax cross-entropy over negated distances.

        Args:
            inp: distances as returned by ``forward``, shape (batch, k - 1).
            target: optional long tensor of class indices, shape (batch,). Defaults
                to all zeros, i.e. the first distance column (built from input
                column 1) is treated as the correct class for every row.

        Returns:
            Scalar loss tensor.
        """
        if target is None:
            # cross_entropy needs a long target with one class index per row;
            # allocate it on inp's device.
            target = inp.new_zeros(inp.shape[:-1], dtype=th.long)
        return fun.cross_entropy(inp.neg(), target)
    
import pandas as pd
import random


def load_graph(vertices_num, path='0_changed.edges'):
    """Load an undirected graph from a headerless, space-separated edge-list file.

    Args:
        vertices_num: vertices are numbered 1..vertices_num; any edge with an
            endpoint greater than this is ignored.
        path: edge-list file, one ``a b`` pair per line.

    Returns:
        dict mapping vertex id -> set of neighbour ids. Vertices left without any
        neighbour are reported with ``print`` and removed. Neighbour ids keep the
        dtype pandas parsed (typically numpy integers).

    Raises:
        ValueError: if the adjacency is not symmetric (a sanity check).
        KeyError: if the file contains an id below 1 (no such vertex key).
    """
    edges_set = {i + 1: set() for i in range(vertices_num)}
    edges_pd = pd.read_csv(path, header=None, delimiter=" ")

    # Iterate the two columns together instead of indexing row by row.
    for a, b in zip(edges_pd[0], edges_pd[1]):
        if a <= vertices_num and b <= vertices_num:
            edges_set[a].add(b)
            edges_set[b].add(a)

    # Sanity check: every edge must be stored in both directions. Walking only
    # the stored edges is O(E) instead of testing every vertex pair (O(V^2)).
    for num1, neighbours in edges_set.items():
        for num2 in neighbours:
            if num1 not in edges_set[num2]:
                raise ValueError("Blad w danych")

    # Drop isolated vertices (after the input-file cleanup none are expected).
    isolated = [vertex for vertex, neighbours in edges_set.items() if not neighbours]
    for vertex in isolated:
        print("Ucinam niepolaczony wierzch nr: ", vertex)
        del edges_set[vertex]

    return edges_set


In [7]:
from random import choice, randint  # samplers used when building training rows

P_Eukl = Euclidean()  # distance backend shared by the models below
graph = load_graph(330)  # {vertex id: set of neighbour ids}; isolated vertices dropped

In [110]:
# Three 2-D embedding models that differ only in table size; all share P_Eukl.
# NOTE(review): the 330-row table has valid rows 0..329, so vertex id 330 is out of range for it.
model_n330 = Model(P_Eukl, n=330, dim=2)
model_n10 = Model(P_Eukl, n=10, dim=2)
model_n520 = Model(P_Eukl, n=520, dim=2)

In [111]:
# Smoke test: 2 rows x 5 ids -> distances of shape (2, 4).
input_test = th.tensor([[1, 2, 3, 4, 5], [3, 4, 3, 2, 9]], dtype=th.long)
model_n10(input_test)  # works

tensor([[0.1301, 1.6150, 1.7790, 2.2804],
        [2.7392, 0.0000, 1.6344, 2.7656]], grad_fn=<SqueezeBackward1>)

In [103]:
# Build a batch of training rows. Row i holds:
#   column 0      -> anchor vertex (i + 1)
#   column 1      -> one random neighbour of the anchor (positive example)
#   columns 2..   -> random vertices that are neither the anchor nor its neighbours (negatives)
# NOTE(review): ids go up to NUM_VERTICES inclusive (randint is inclusive), so the
# consuming model needs an embedding table with at least NUM_VERTICES + 1 rows.
BATCH_ROWS = 10
ROW_WIDTH = 52
NUM_VERTICES = 330

inputs = th.zeros((BATCH_ROWS, ROW_WIDTH), dtype=th.long)  # every entry is overwritten below
for i in range(BATCH_ROWS):
    batch = i + 1  # anchor id; KeyError if load_graph dropped it as isolated
    batch_connected = choice(tuple(graph[batch]))
    inputs[i, 0] = batch
    inputs[i, 1] = batch_connected
    for j in range(2, ROW_WIDTH):
        candidate = randint(1, NUM_VERTICES)
        # Reject the anchor itself too: a self-pair has distance 0 and is not a valid negative.
        while candidate == batch or candidate in graph[batch]:
            candidate = randint(1, NUM_VERTICES)
        inputs[i, j] = candidate

In [113]:
# NOTE(review): `inputs` holds vertex ids up to 330 (randint is inclusive), which would
# be out of range for model_n330's 330-row table (valid rows 0..329); presumably that is
# why the 520-row model is used here — confirm.
model_n520(inputs)

tensor([[1.9103, 1.1239, 0.4525, 1.6597, 1.7578, 1.3541, 1.0932, 1.2980, 2.7206,
         2.2948, 1.7337, 2.4548, 1.7046, 2.2676, 2.2948, 2.2676, 2.8214, 1.3031,
         0.4249, 3.6460, 1.9101, 2.6832, 2.1680, 0.3705, 0.8415, 1.9503, 2.9763,
         0.8685, 1.0664, 2.9355, 2.5749, 0.2611, 2.1939, 2.3237, 2.8502, 1.1023,
         2.4086, 3.2502, 1.8479, 1.2779, 2.0304, 1.8911, 2.3412, 1.0245, 1.3193,
         2.1194, 2.6501, 0.8439, 0.5808, 2.8502, 1.6055],
        [1.3925, 1.6917, 2.4082, 0.7898, 0.7097, 0.3200, 2.0130, 1.7219, 2.7803,
         3.2535, 1.3827, 0.8240, 2.4082, 1.8787, 0.7442, 0.2808, 2.2491, 0.9688,
         0.9951, 0.6821, 1.5274, 1.6874, 2.9505, 2.3207, 1.7736, 1.9696, 2.7995,
         2.1858, 2.1054, 3.0145, 0.8384, 2.6579, 0.4749, 0.9688, 1.7384, 1.5572,
         0.4694, 1.4998, 1.6410, 2.7995, 1.1651, 0.2996, 0.6249, 2.0130, 2.4566,
         1.3340, 2.1560, 2.2850, 0.9750, 1.9422, 2.4566],
        [2.3186, 2.8060, 1.4454, 1.7178, 1.4546, 3.7746, 1.1816, 1.4890, 2