In [None]:
from graph import Graph, Part
from typing import Set
from utils import get_ordering, get_splits

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from datetime import datetime

class LSTM(nn.Module):
    """Single-layer LSTM followed by a linear head and a log-softmax.

    ``forward`` treats the 1-D/2-D input as one sequence (batch of 1) and
    returns one row of log-probabilities over ``output_size`` classes per
    time step.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        # Keep registration order (lstm, then hidden2tag): it determines the
        # order of model.parameters() and of random initialisation.
        self.lstm = nn.LSTM(input_size, hidden_size)
        self.hidden2tag = nn.Linear(hidden_size, output_size)

    def forward(self, sentence):
        """Return log-softmax scores of shape ``(len(sentence), output_size)``."""
        steps = len(sentence)
        # nn.LSTM (batch_first=False) expects (seq_len, batch, features);
        # the whole input is a single sequence, hence batch size 1.
        sequence = sentence.view(steps, 1, -1)
        hidden_states, _ = self.lstm(sequence)
        logits = self.hidden2tag(hidden_states.view(steps, -1))
        return F.log_softmax(logits, dim=1)

def integers_to_onehot(numbers, output_size):
    """Encode a sized sequence of class indices as a one-hot matrix.

    Args:
        numbers: Sized iterable of integer indices; row ``i`` of the result has
            a 1 in column ``numbers[i]`` (ordinary tensor indexing rules apply).
        output_size: Number of columns (classes) in the encoding.

    Returns:
        A tensor of the default float dtype with shape ``(len(numbers), output_size)``.
    """
    rows = torch.zeros(len(numbers), output_size)
    for row_index, hot_column in enumerate(numbers):
        rows[row_index][hot_column] = 1
    return rows

# Define hyperparameters
input_size = 1  # one feature per time step: a part value (0..2270) scaled into [0, 1]
hidden_size = 20
output_size = 20  # number of position classes; every target position must be < output_size
model = LSTM(input_size, hidden_size, output_size)
# CrossEntropyLoss applies log_softmax itself. LSTM.forward already returns
# log-probabilities, which is harmless because log_softmax is idempotent.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

ordering = get_ordering()
splits = get_splits()
avgOrder_compatibleGraphs = ordering.get_compatible_graphs(splits["y_train"])

# Example training loop
for epoch in range(1000):
    total_loss = 0.0
    num_samples = 0
    for _graph, seq, positions in avgOrder_compatibleGraphs:
        # Scale each value (documented above as 0..2270) into [0, 1].
        features = torch.tensor([value / 2270 for value in seq], dtype=torch.float32)
        # One-hot rows are interpreted by CrossEntropyLoss as class probabilities.
        target = integers_to_onehot(positions, output_size)

        optimizer.zero_grad()
        output = model(features)
        # output is (len(seq), output_size) and target is (len(positions), output_size):
        # already the (N, C) layout CrossEntropyLoss expects, so no extra .view is
        # needed. Assumes len(positions) == len(seq) - TODO confirm in get_compatible_graphs.
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_samples += 1

    if epoch % 100 == 0:
        # Mean loss per training sequence (previously divided by a hard-coded 10).
        mean_loss = total_loss / max(num_samples, 1)
        print(f'{datetime.now().strftime("%H:%M:%S")}: Epoch {epoch}, Loss: {mean_loss}')

In [None]:
def createGraph(unorderedParts: Set[Part], model: nn.Module, ordering=None) -> Graph:
    """Predict a graph over ``unorderedParts`` using the trained sequence model.

    The parts are sorted with ``ordering.sort``, fed to ``model`` as a sequence
    of part ids divided by 2270, and for each part the argmax position of the
    model output selects the part it is joined to by an undirected edge.

    Args:
        unorderedParts: Parts to connect.
        model: Module returning per-step scores of shape ``(len(parts), C)``.
        ordering: Optional pre-built ordering; defaults to ``get_ordering()``.
            Passing one avoids rebuilding it on every call.

    Returns:
        A new ``Graph`` with one ``add_undirected_edge`` call per part, or an
        empty ``Graph`` when ``unorderedParts`` is empty.
    """
    if ordering is None:
        ordering = get_ordering()
    parts = ordering.sort(unorderedParts)
    g = Graph()
    if len(parts) == 0:
        # An empty sequence cannot pass through the model's view(len, 1, -1).
        return g

    features = torch.Tensor([int(part.get_part_id()) / 2270 for part in parts])
    # Inference only: do not record an autograd graph.
    with torch.no_grad():
        scores = model(features)
    output_positions = torch.argmax(scores, dim=1).tolist()

    for idx, pos in enumerate(output_positions):
        # Predicted positions past the last part fall back to the first part.
        target_idx = pos if pos < len(parts) else 0
        g.add_undirected_edge(parts[idx], parts[target_idx])

    # argmax indices are non-negative, so any() == "sum > 0" on the raw output.
    if any(output_positions):
        print(f"Non-Zero Model output for: {unorderedParts}")
    return g


# Fraction of training examples whose predicted graph equals the ground truth.
# NOTE(review): this measures fit on the training split only, and equality
# relies on Graph.__eq__ - confirm it compares structure, not identity.
correct_counter = 0
evaluated = 0
with torch.no_grad():  # inference only; no gradients needed
    for parts, graph in zip(splits["x_train"], splits["y_train"]):
        prediction = createGraph(parts, model)
        if prediction == graph:
            correct_counter += 1
        evaluated += 1
# Divide by the number of pairs actually compared (zip stops at the shorter
# split) and avoid ZeroDivisionError on an empty split.
print(correct_counter / evaluated if evaluated else 0.0)