In [None]:
from graph import Graph, Part
from typing import Set
from utils import get_ordering, get_splits

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from datetime import datetime
from tqdm import trange, tqdm

class LSTM(nn.Module):
    """Single-layer LSTM followed by a linear layer and a log-softmax over output classes.

    ``forward`` treats its argument as one unbatched sequence and returns per-step
    log-probabilities of shape ``(len(sentence), output_size)``.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        # Submodule names become the saved state_dict keys; keep them (and their order) stable.
        self.lstm = nn.LSTM(input_size, hidden_size)
        self.hidden2tag = nn.Linear(hidden_size, output_size)

    def forward(self, sentence):
        steps = len(sentence)
        # nn.LSTM defaults to (seq_len, batch, features); the whole input is a single batch item.
        per_step_features = sentence.view(steps, 1, -1)
        hidden_seq = self.lstm(per_step_features)[0]
        logits = self.hidden2tag(hidden_seq.view(steps, -1))
        return F.log_softmax(logits, dim=1)

def integers_to_onehot(numbers, output_size, device):
    """Encode a sequence of class indices as a dense one-hot float matrix.

    Row ``r`` of the result holds a 1 in column ``numbers[r]`` and zeros elsewhere.
    The result has shape ``(len(numbers), output_size)`` and is allocated on ``device``.
    """
    encoded = torch.zeros((len(numbers), output_size), device=device)
    for row, column in zip(range(len(numbers)), numbers):
        encoded[row, column] = 1
    return encoded

# Define hyperparameters
MAX_PART_ID = 2270  # scale factor for sequence values fed to the LSTM (createGraph divides part ids by the same 2270)
input_size = 1  # one scalar feature per time step: value / MAX_PART_ID (raw values presumably lie in [0, 2270])
hidden_size = 20
output_size = 20  # number of classes scored per step (also the width of the one-hot targets)
SEED = 0

# Seed before building the model so weight initialisation, and therefore the results, are reproducible.
torch.manual_seed(SEED)

# Training runs on CPU. To use a GPU when available, replace with:
#   device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")

model = LSTM(input_size, hidden_size, output_size)
model.to(device)
# The model already outputs log-probabilities. CrossEntropyLoss applies log_softmax again,
# which leaves log-probabilities unchanged, so this equals NLLLoss against the one-hot
# (class-probability) targets built below. Probability targets need PyTorch >= 1.10.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

ordering = get_ordering()
splits = get_splits()
# Materialised so it can be re-iterated every epoch even if a generator/iterator is returned.
# Iterated below as (graph, seq, positions) triples; the graph itself is not used for training.
avgOrder_compatibleGraphs = list(ordering.get_compatible_graphs(splits["y_train"]))

model.train()
# Example training loop
epochs = 1000
progress = trange(epochs, desc='LSTM', leave=True)
for epoch in progress:
    total_loss = 0
    for _graph, seq, positions in avgOrder_compatibleGraphs:
        # Named `inputs` (not `input`) so the builtin is not shadowed in the notebook namespace.
        inputs = torch.tensor([value / MAX_PART_ID for value in seq], device=device)
        target = integers_to_onehot(positions, output_size, device)

        optimizer.zero_grad()
        output = model(inputs)
        # output is (len(seq), output_size) and target is (len(positions), output_size);
        # assumes len(seq) == len(positions), so no reshape of output is needed.
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    progress.set_description(f'{datetime.now().strftime("%H:%M:%S")}: Loss: {total_loss}')

torch.save(model.state_dict(), "./LSTM_model.pth")

In [None]:
def createGraph(unorderedParts: Set[Part], model: nn.Module):
    """Predict an undirected graph over ``unorderedParts`` with a trained sequence model.

    The parts are sorted with the ordering from ``get_ordering()``. Each part id is
    converted to ``int`` and divided by 2270, which is the same scaling the training loop applies.
    The model is then run once over the whole sequence. For every part, the arg-max output
    class is used as the index of the part to connect it to.

    Args:
        unorderedParts: Parts to connect.
        model: Module expected to map a 1-D float tensor of length n to scores of
            shape (n, C). It is switched to eval mode as a side effect.

    Returns:
        A ``Graph`` with one ``add_undirected_edge`` call per part, or an empty
        ``Graph`` when ``unorderedParts`` is empty.
    """
    model.eval()
    ordering = get_ordering()
    parts = ordering.sort(unorderedParts)
    g = Graph()
    if len(parts) == 0:
        # An empty sequence cannot go through the LSTM (view(0, 1, -1) raises), and
        # there is nothing to connect anyway.
        return g

    # Build the input on the model's own device instead of relying on a notebook-global
    # `device` defined in another cell.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")
    inputs = torch.tensor([int(part.get_part_id()) / 2270 for part in parts], device=model_device)
    with torch.no_grad():
        scores = model(inputs)
    output_positions = torch.argmax(scores, dim=1).tolist()

    for idx, pos in enumerate(output_positions):
        # The model can emit any of its C classes; an index with no matching part
        # falls back to the first part.
        if pos >= len(parts):
            pos = 0
        # NOTE(review): when pos == idx this adds a self-loop (e.g. for part 0 whenever the
        # prediction is 0 or out of range) — confirm Graph is meant to receive that edge.
        g.add_undirected_edge(parts[idx], parts[pos])
    return g

def _exact_match_accuracy(inputs, targets, model, desc):
    """Return the fraction of samples where ``createGraph(parts, model) == target``.

    Args:
        inputs: Sequence of part collections (one per sample).
        targets: Sequence of expected graphs, aligned with ``inputs``.
        model: Trained model passed through to ``createGraph``.
        desc: Label for the progress bar and for error messages.

    Returns:
        Accuracy in [0, 1], or NaN for an empty split (accuracy is undefined there).

    Raises:
        ValueError: If ``inputs`` and ``targets`` differ in length. Without this check,
            ``zip`` would silently drop the unmatched tail while the denominator still counted it.
    """
    if len(inputs) != len(targets):
        raise ValueError(f"{desc}: {len(inputs)} inputs but {len(targets)} targets")
    if len(targets) == 0:
        return float("nan")
    correct = sum(
        1
        for parts, graph in tqdm(zip(inputs, targets), total=len(inputs), desc=desc)
        if createGraph(parts, model) == graph
    )
    return correct / len(targets)


splits = get_splits()
correct_train = _exact_match_accuracy(splits["x_train"], splits["y_train"], model, "train")
correct_val = _exact_match_accuracy(splits["x_val"], splits["y_val"], model, "val")

print(f"Accuracy[CompletelyCorrect] in Train:{correct_train}")
print(f"Accuracy[CompletelyCorrect] in Val:{correct_val}")