# GNN for Parity Games Experiments

## Raw Training Data Creation

**07 July 2022**

- Num games: 3000
- Graph size (N) range: [10,200]
- Relative outdegree $(\frac{d_i}{N})$ range: [0.01,0.5]
- Priority range: [0,N]

In [1]:
# Settings for generating the raw training games (mirrors the list above).
num_graphs = 3000             # number of games to generate
min_n, max_n = 10, 200        # range of the graph size N (number of vertices)
min_rod, max_rod = 0.01, 0.5  # range of the relative out-degree d_i / N

# Where generated games / pgsolver solutions are written, and the solver itself.
games_dir, solutions_dir = 'games', 'solutions'
pgsolver_base = 'pgsolver' # Path to compiled base dir of https://github.com/tcsprojects/pgsolver 

In [2]:
#import game_generator as gg

#gg.create_games_and_solutions(num_graphs, min_n, max_n, min_rod, max_rod, games_dir, solutions_dir, pgsolver_base)

## Prepare Training

In [2]:
from torch_geometric.loader import DataLoader
from parity_game_dataset import ParityGameDataset
import math

# Load the generated games together with their pgsolver solutions.
data = ParityGameDataset('pg_data_20220708', 'games', 'solutions')

# Deterministic split in dataset order: the first 70% of the games train the
# model, the remaining 30% are held out for testing.
train_fraction = 0.7
split_index = math.floor(train_fraction * len(data))
train_data, test_data = data[:split_index], data[split_index:]

# One game per batch; only the training games are reshuffled every epoch.
train_loader = DataLoader(train_data, batch_size=1, shuffle=True)
test_loader = DataLoader(test_data, batch_size=1, shuffle=False)



**Training Parameters**

- Optimizer: Adam
- Learning rate: 0.001
- Loss: cross entropy on the node labels + class-weighted cross entropy (weights [0.1, 0.9]) on the edge labels
- Epochs: 1-100

## Prepare model

**Parameters**

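- Iterations (GAT layers): 10
- Hidden size: 256
- Node and edge classifier heads: Linear → ReLU → Dropout(0.2) → Linear → Softmax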

In [3]:
from importlib import reload
import parity_game_network as pn

# Re-import the network module so edits to parity_game_network.py take effect
# without restarting the kernel (reload returns the same, refreshed module).
pn = reload(pn)

# NOTE(review): arguments are presumably (hidden size, hidden size, number of GNN
# layers) - the printed model shows GAT(3, 256, num_layers=10); confirm in
# parity_game_network.
model = pn.ParityGameNetwork(256, 256, 10)

In [5]:
# Display the model's module structure as the notebook cell output.
model

ParityGameNetwork(
  (core): GAT(3, 256, num_layers=10)
  (node_classifier): Sequential(
    (0): Linear(in_features=256, out_features=256, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.2, inplace=False)
    (3): Linear(in_features=256, out_features=2, bias=True)
    (4): Softmax(dim=1)
  )
  (edge_classifier): Sequential(
    (0): Linear(in_features=512, out_features=256, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.2, inplace=False)
    (3): Linear(in_features=256, out_features=2, bias=True)
    (4): Softmax(dim=1)
  )
)

# Run Training Loop

In [6]:
import wandb

# Start a W&B run and record the hyper-parameters used by the optimizer below.
wandb.init(project='gnn_parity_game_solver')
config = wandb.config
config.update({'learning_rate': 0.001})

# Track gradients and parameters of the model in the W&B run.
wandb.watch(model, log='all')

[34m[1mwandb[0m: Currently logged in as: [33malexdweinert[0m. Use [1m`wandb login --relogin`[0m to force relogin


[]

In [7]:
import torch
import numpy as np

# Adam over all model parameters; keep lr in sync with config.learning_rate.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# NOTE(review): the printed model ends both classifier heads in Softmax(dim=1),
# but CrossEntropyLoss applies log-softmax itself and expects raw logits, so the
# outputs are normalised twice (which flattens gradients). Consider removing the
# Softmax layers in parity_game_network and applying softmax only at inference.
criterion_nodes = torch.nn.CrossEntropyLoss()
# Edges in a winning strategy are rare, so class 1 is weighted 9x class 0.
criterion_edges = torch.nn.CrossEntropyLoss(weight=torch.tensor([0.1, 0.9]))

def train(log_interval=50):
    """Run one training epoch over ``train_loader``.

    The loss of each batch is the node-classification loss plus the
    class-weighted edge-classification loss. Every ``log_interval`` batches, the
    mean loss over that window and the variance of the class-1 node output of
    the current batch are logged to W&B.

    Args:
        log_interval: number of batches per logged loss window (default 50,
            the value that used to be hard-coded).
    """
    model.train()
    running_loss = 0.

    for i, batch in enumerate(train_loader, start=1):  # Iterate in batches over the training dataset.
        optimizer.zero_grad()  # Clear gradients.
        out_nodes, out_edges = model(batch.x, batch.edge_index)  # Perform a single forward pass.

        # Most edges do not belong to a winning strategy, so the edge labels are
        # extremely imbalanced and always predicting "non-winning" already scores
        # well. criterion_edges counteracts this with class weights; no edges are
        # sub-sampled here.
        loss = criterion_nodes(out_nodes, batch.y_nodes) + criterion_edges(out_edges, batch.y_edges)
        loss.backward()  # Derive gradients.
        optimizer.step()  # Update parameters based on gradients.

        running_loss += loss.item()
        if i % log_interval == 0:
            last_loss = running_loss / log_interval  # mean loss per batch in this window
            # .item() logs a plain float instead of a tensor attached to the autograd graph.
            wandb.log({'loss': last_loss, 'variance': torch.var(out_nodes[:, 1]).item()})
            running_loss = 0.
            
def test(loader):
    """Return the accuracy of ``model`` on ``loader`` as ``(nodes, edges)``.

    Accuracy is computed per batch (fraction of correctly classified nodes and
    edges) and then averaged over all batches; with batch_size=1, as used by the
    loaders in this notebook, every game counts equally regardless of its size.

    Args:
        loader: iterable of batches exposing ``x``, ``edge_index``, ``y_nodes``
            and ``y_edges`` that also supports ``len()``.

    Returns:
        Tuple ``(node_accuracy, edge_accuracy)`` of Python floats; both are
        ``nan`` when ``loader`` is empty.
    """
    model.eval()

    num_batches = len(loader)
    if num_batches == 0:
        return float('nan'), float('nan')

    correct_nodes = 0.
    correct_edges = 0.
    # Evaluation only: skip building the autograd graph.
    with torch.no_grad():
        for batch in loader:
            out_nodes, out_edges = model(batch.x, batch.edge_index)
            pred_nodes = out_nodes.argmax(dim=1)  # Use the class with highest probability.
            pred_edges = out_edges.argmax(dim=1)
            correct_nodes += (pred_nodes == batch.y_nodes).float().mean().item()
            correct_edges += (pred_edges == batch.y_edges).float().mean().item()
    return correct_nodes / num_batches, correct_edges / num_batches

# Number of training epochs for this run.
num_epochs = 1

for epoch in range(1, num_epochs + 1):
    train()
    train_acc_nodes, train_acc_edges = test(train_loader)
    test_acc_nodes, test_acc_edges = test(test_loader)
    # Metric keys follow the <Split>_Acc_<Target> pattern (was mistyped 'Acc_Acc_Nodes').
    wandb.log({'Epoch': epoch, 'Train_Acc_Nodes': train_acc_nodes, 'Test_Acc_Nodes': test_acc_nodes, 'Train_Acc_Edges': train_acc_edges, 'Test_Acc_Edges': test_acc_edges})
    print(f'Epoch: {epoch:03d}, Train Acc Nodes: {train_acc_nodes:.4f}, Test Acc Nodes: {test_acc_nodes:.4f}, Train Acc Edges: {train_acc_edges:.4f}, Test Acc Edges: {test_acc_edges:.4f}')

Epoch: 001, Train Acc Nodes: 0.9665, Test Acc Nodes: 0.9620, Train Acc Edges: 0.9092, Test Acc Edges: 0.9080


*Saving the following model:*
- Performance: Epoch: 001, Train Acc Nodes: 0.9832, Test Acc Nodes: 0.9824, Train Acc Edges: 0.9687, Test Acc Edges: 0.9685

In [36]:
# Print the architecture of the model that is about to be saved.
print(model)

ParityGameNetwork(
  (core): GAT(3, 256, num_layers=10)
  (node_classifier): Sequential(
    (0): Linear(in_features=256, out_features=256, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.2, inplace=False)
    (3): Linear(in_features=256, out_features=2, bias=True)
    (4): Softmax(dim=1)
  )
  (edge_classifier): Sequential(
    (0): Linear(in_features=512, out_features=256, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.2, inplace=False)
    (3): Linear(in_features=256, out_features=2, bias=True)
    (4): Softmax(dim=1)
  )
)


In [10]:
# Persist only the learned parameters (the state_dict), not the module object.
checkpoint_path = 'GAT_pg_solver_20220914.pth'
torch.save(model.state_dict(), checkpoint_path)

In [33]:
import torch
import numpy as np
# Restore the saved parameters into `model`. map_location='cpu' lets a checkpoint
# written on a GPU machine load on a CPU-only one; load_state_dict then copies the
# values into the model's existing parameters.
model.load_state_dict(torch.load('GAT_pg_solver_20220914.pth', map_location='cpu'), strict=True)

<All keys matched successfully>

## Unparsing predicted strategies

In [13]:
import os

# Grab the first game of the (unshuffled) test split.
it = iter(test_loader)
game = next(it)

# NOTE(review): os.listdir returns names in arbitrary order, so these entries only
# line up with the test split if ParityGameDataset enumerates 'games' the same way.
game_files = os.listdir('games')
game_files[split_index:]

['game_2889.txt',
 'game_289.txt',
 'game_2890.txt',
 'game_2891.txt',
 'game_2892.txt',
 'game_2893.txt',
 'game_2894.txt',
 'game_2895.txt',
 'game_2896.txt',
 'game_2897.txt',
 'game_2898.txt',
 'game_2899.txt',
 'game_29.txt',
 'game_290.txt',
 'game_2900.txt',
 'game_2901.txt',
 'game_2902.txt',
 'game_2903.txt',
 'game_2904.txt',
 'game_2905.txt',
 'game_2906.txt',
 'game_2907.txt',
 'game_2908.txt',
 'game_2909.txt',
 'game_291.txt',
 'game_2910.txt',
 'game_2911.txt',
 'game_2912.txt',
 'game_2913.txt',
 'game_2914.txt',
 'game_2915.txt',
 'game_2916.txt',
 'game_2917.txt',
 'game_2918.txt',
 'game_2919.txt',
 'game_292.txt',
 'game_2920.txt',
 'game_2921.txt',
 'game_2922.txt',
 'game_2923.txt',
 'game_2924.txt',
 'game_2925.txt',
 'game_2926.txt',
 'game_2927.txt',
 'game_2928.txt',
 'game_2929.txt',
 'game_293.txt',
 'game_2930.txt',
 'game_2931.txt',
 'game_2932.txt',
 'game_2933.txt',
 'game_2934.txt',
 'game_2935.txt',
 'game_2936.txt',
 'game_2937.txt',
 'game_2938.txt',

In [39]:
import pg_parser as parser
import os
import time

# Set to True to unparse the predicted strategies into solution files under
# generated_solutions/. While False the loop only reports the inference time and
# the number of selected edges per game.
WRITE_SOLUTIONS = False

# NOTE(review): os.listdir returns names in arbitrary order; game_list[i] is the
# file behind the i-th dataset entry only if ParityGameDataset enumerates 'games'
# in the same order - confirm before trusting the output file names.
game_list = os.listdir('games')


def _node_owner(game, node):
    """Return 0 if feature column 1 of ``node`` is non-zero, else 1.

    NOTE(review): assumes game.x[:, 1] flags vertices of player 0 - confirm
    against ParityGameDataset.
    """
    return 0 if game.x[node][1].item() else 1


def _winning_edge_per_vertex(game, winning_edges_tensor):
    """Map each vertex to the index of a selected outgoing edge, or None if none was selected."""
    winning_edge_index = [None] * len(game.x)
    for edge_index in torch.nonzero(winning_edges_tensor).flatten().tolist():
        start_vertex = game.edge_index[0][edge_index].item()
        if winning_edge_index[start_vertex] is not None:
            # Several edges selected for one vertex; the last one wins.
            print("oh no")
        winning_edge_index[start_vertex] = edge_index
    return winning_edge_index


model.eval()  # inference: disable dropout
game_it = iter(test_loader)
with torch.no_grad():
    for i in range(split_index, len(data)):
        game = next(game_it)

        timestamp = time.perf_counter()
        out_nodes, out_edges = model(game.x, game.edge_index)
        duration = time.perf_counter() - timestamp

        print("Model application took " + str(duration) + " seconds")
        winning_edges_tensor = out_edges.argmax(dim=1)
        print("Picked " + str(torch.count_nonzero(winning_edges_tensor)) + " edges")
        if not WRITE_SOLUTIONS:
            continue

        winning_edge_index = _winning_edge_per_vertex(game, winning_edges_tensor)

        for vertex, edge_index in enumerate(winning_edge_index):
            if edge_index is None:
                print("Could not determine winning edge for vertex " + str(vertex))
                print("Vertex belongs to player " + str(_node_owner(game, vertex)))

        os.makedirs("generated_solutions", exist_ok=True)
        solution_path = "generated_solutions/gensolution_" + game_list[i]
        with open(solution_path, 'w') as output_file:
            output_file.write(f"parity {len(game.x)};")
            for start_node, edge_index in enumerate(winning_edge_index):
                if edge_index is None:
                    continue  # no selected edge for this vertex (reported above)
                owner = _node_owner(game, start_node)
                next_node = game.edge_index[1][edge_index].item()
                output_file.write(f"\n{start_node} {owner} {next_node};")
            output_file.write("\n")

DataBatch(x=[142, 3], edge_index=[2, 5176], y_nodes=[142], y_edges=[5176], batch=[142], ptr=[2])
Model application took 0.13241004943847656 seconds
Picked tensor(0) edges
DataBatch(x=[14, 3], edge_index=[2, 72], y_nodes=[14], y_edges=[72], batch=[14], ptr=[2])
Model application took 0.03391432762145996 seconds
Picked tensor(0) edges
DataBatch(x=[95, 3], edge_index=[2, 2208], y_nodes=[95], y_edges=[2208], batch=[95], ptr=[2])
Model application took 0.10423970222473145 seconds
Picked tensor(0) edges
DataBatch(x=[58, 3], edge_index=[2, 920], y_nodes=[58], y_edges=[920], batch=[58], ptr=[2])
Model application took 0.08709049224853516 seconds
Picked tensor(0) edges
DataBatch(x=[170, 3], edge_index=[2, 7422], y_nodes=[170], y_edges=[7422], batch=[170], ptr=[2])
Model application took 0.16501092910766602 seconds
Picked tensor(0) edges
DataBatch(x=[191, 3], edge_index=[2, 9484], y_nodes=[191], y_edges=[9484], batch=[191], ptr=[2])
Model application took 0.16856813430786133 seconds
Picked tenso

Model application took 0.12027192115783691 seconds
Picked tensor(0) edges
DataBatch(x=[194, 3], edge_index=[2, 9639], y_nodes=[194], y_edges=[9639], batch=[194], ptr=[2])
Model application took 0.171630859375 seconds
Picked tensor(0) edges
DataBatch(x=[105, 3], edge_index=[2, 2920], y_nodes=[105], y_edges=[2920], batch=[105], ptr=[2])
Model application took 0.10576033592224121 seconds
Picked tensor(0) edges
DataBatch(x=[178, 3], edge_index=[2, 8395], y_nodes=[178], y_edges=[8395], batch=[178], ptr=[2])
Model application took 0.1612095832824707 seconds
Picked tensor(0) edges
DataBatch(x=[141, 3], edge_index=[2, 4795], y_nodes=[141], y_edges=[4795], batch=[141], ptr=[2])
Model application took 0.1049046516418457 seconds
Picked tensor(0) edges
DataBatch(x=[140, 3], edge_index=[2, 4670], y_nodes=[140], y_edges=[4670], batch=[140], ptr=[2])
Model application took 0.11248230934143066 seconds
Picked tensor(0) edges
DataBatch(x=[86, 3], edge_index=[2, 1753], y_nodes=[86], y_edges=[1753], batch

Model application took 0.10180783271789551 seconds
Picked tensor(0) edges
DataBatch(x=[57, 3], edge_index=[2, 839], y_nodes=[57], y_edges=[839], batch=[57], ptr=[2])
Model application took 0.06971168518066406 seconds
Picked tensor(0) edges
DataBatch(x=[178, 3], edge_index=[2, 8396], y_nodes=[178], y_edges=[8396], batch=[178], ptr=[2])
Model application took 0.1528916358947754 seconds
Picked tensor(0) edges
DataBatch(x=[32, 3], edge_index=[2, 227], y_nodes=[32], y_edges=[227], batch=[32], ptr=[2])
Model application took 0.0426173210144043 seconds
Picked tensor(0) edges
DataBatch(x=[13, 3], edge_index=[2, 57], y_nodes=[13], y_edges=[57], batch=[13], ptr=[2])
Model application took 0.03431415557861328 seconds
Picked tensor(0) edges
DataBatch(x=[37, 3], edge_index=[2, 365], y_nodes=[37], y_edges=[365], batch=[37], ptr=[2])
Model application took 0.047470808029174805 seconds
Picked tensor(0) edges
DataBatch(x=[93, 3], edge_index=[2, 2270], y_nodes=[93], y_edges=[2270], batch=[93], ptr=[2])


Model application took 0.16327548027038574 seconds
Picked tensor(0) edges
DataBatch(x=[109, 3], edge_index=[2, 3080], y_nodes=[109], y_edges=[3080], batch=[109], ptr=[2])
Model application took 0.08666419982910156 seconds
Picked tensor(0) edges
DataBatch(x=[118, 3], edge_index=[2, 3593], y_nodes=[118], y_edges=[3593], batch=[118], ptr=[2])
Model application took 0.09661531448364258 seconds
Picked tensor(0) edges
DataBatch(x=[36, 3], edge_index=[2, 340], y_nodes=[36], y_edges=[340], batch=[36], ptr=[2])
Model application took 0.044432878494262695 seconds
Picked tensor(0) edges
DataBatch(x=[139, 3], edge_index=[2, 4628], y_nodes=[139], y_edges=[4628], batch=[139], ptr=[2])
Model application took 0.11955046653747559 seconds
Picked tensor(0) edges
DataBatch(x=[80, 3], edge_index=[2, 1501], y_nodes=[80], y_edges=[1501], batch=[80], ptr=[2])
Model application took 0.07458209991455078 seconds
Picked tensor(0) edges
DataBatch(x=[78, 3], edge_index=[2, 1572], y_nodes=[78], y_edges=[1572], batch

Model application took 0.11922883987426758 seconds
Picked tensor(0) edges
DataBatch(x=[100, 3], edge_index=[2, 2485], y_nodes=[100], y_edges=[2485], batch=[100], ptr=[2])
Model application took 0.08928775787353516 seconds
Picked tensor(0) edges
DataBatch(x=[148, 3], edge_index=[2, 5587], y_nodes=[148], y_edges=[5587], batch=[148], ptr=[2])
Model application took 0.11599135398864746 seconds
Picked tensor(0) edges
DataBatch(x=[188, 3], edge_index=[2, 8841], y_nodes=[188], y_edges=[8841], batch=[188], ptr=[2])
Model application took 0.16143107414245605 seconds
Picked tensor(0) edges
DataBatch(x=[49, 3], edge_index=[2, 702], y_nodes=[49], y_edges=[702], batch=[49], ptr=[2])
Model application took 0.050571441650390625 seconds
Picked tensor(0) edges
DataBatch(x=[147, 3], edge_index=[2, 5701], y_nodes=[147], y_edges=[5701], batch=[147], ptr=[2])
Model application took 0.1282210350036621 seconds
Picked tensor(0) edges
DataBatch(x=[26, 3], edge_index=[2, 223], y_nodes=[26], y_edges=[223], batch

Model application took 0.16311359405517578 seconds
Picked tensor(0) edges
DataBatch(x=[134, 3], edge_index=[2, 4938], y_nodes=[134], y_edges=[4938], batch=[134], ptr=[2])
Model application took 0.09978342056274414 seconds
Picked tensor(0) edges
DataBatch(x=[54, 3], edge_index=[2, 801], y_nodes=[54], y_edges=[801], batch=[54], ptr=[2])
Model application took 0.06462645530700684 seconds
Picked tensor(0) edges
DataBatch(x=[194, 3], edge_index=[2, 9358], y_nodes=[194], y_edges=[9358], batch=[194], ptr=[2])
Model application took 0.16232514381408691 seconds
Picked tensor(0) edges
DataBatch(x=[192, 3], edge_index=[2, 8744], y_nodes=[192], y_edges=[8744], batch=[192], ptr=[2])
Model application took 0.15796875953674316 seconds
Picked tensor(0) edges
DataBatch(x=[123, 3], edge_index=[2, 4124], y_nodes=[123], y_edges=[4124], batch=[123], ptr=[2])
Model application took 0.10674881935119629 seconds
Picked tensor(0) edges
DataBatch(x=[116, 3], edge_index=[2, 3571], y_nodes=[116], y_edges=[3571], b

Model application took 0.026556968688964844 seconds
Picked tensor(0) edges
DataBatch(x=[157, 3], edge_index=[2, 6188], y_nodes=[157], y_edges=[6188], batch=[157], ptr=[2])
Model application took 0.13948750495910645 seconds
Picked tensor(0) edges
DataBatch(x=[52, 3], edge_index=[2, 639], y_nodes=[52], y_edges=[639], batch=[52], ptr=[2])
Model application took 0.0656583309173584 seconds
Picked tensor(0) edges
DataBatch(x=[72, 3], edge_index=[2, 1502], y_nodes=[72], y_edges=[1502], batch=[72], ptr=[2])
Model application took 0.07752823829650879 seconds
Picked tensor(0) edges
DataBatch(x=[82, 3], edge_index=[2, 1733], y_nodes=[82], y_edges=[1733], batch=[82], ptr=[2])
Model application took 0.07513737678527832 seconds
Picked tensor(0) edges
DataBatch(x=[156, 3], edge_index=[2, 6681], y_nodes=[156], y_edges=[6681], batch=[156], ptr=[2])
Model application took 0.12672924995422363 seconds
Picked tensor(0) edges
DataBatch(x=[169, 3], edge_index=[2, 8007], y_nodes=[169], y_edges=[8007], batch=[

Model application took 0.08266448974609375 seconds
Picked tensor(0) edges
DataBatch(x=[58, 3], edge_index=[2, 853], y_nodes=[58], y_edges=[853], batch=[58], ptr=[2])
Model application took 0.06634163856506348 seconds
Picked tensor(0) edges
DataBatch(x=[145, 3], edge_index=[2, 4808], y_nodes=[145], y_edges=[4808], batch=[145], ptr=[2])
Model application took 0.10767960548400879 seconds
Picked tensor(0) edges
DataBatch(x=[47, 3], edge_index=[2, 509], y_nodes=[47], y_edges=[509], batch=[47], ptr=[2])
Model application took 0.050055503845214844 seconds
Picked tensor(0) edges
DataBatch(x=[193, 3], edge_index=[2, 9630], y_nodes=[193], y_edges=[9630], batch=[193], ptr=[2])
Model application took 0.1642918586730957 seconds
Picked tensor(0) edges
DataBatch(x=[30, 3], edge_index=[2, 229], y_nodes=[30], y_edges=[229], batch=[30], ptr=[2])
Model application took 0.04012012481689453 seconds
Picked tensor(0) edges
DataBatch(x=[90, 3], edge_index=[2, 2211], y_nodes=[90], y_edges=[2211], batch=[90], p

Picked tensor(0) edges
DataBatch(x=[74, 3], edge_index=[2, 1313], y_nodes=[74], y_edges=[1313], batch=[74], ptr=[2])
Model application took 0.07796430587768555 seconds
Picked tensor(0) edges
DataBatch(x=[121, 3], edge_index=[2, 3578], y_nodes=[121], y_edges=[3578], batch=[121], ptr=[2])
Model application took 0.09171533584594727 seconds
Picked tensor(0) edges
DataBatch(x=[145, 3], edge_index=[2, 5134], y_nodes=[145], y_edges=[5134], batch=[145], ptr=[2])
Model application took 0.10788416862487793 seconds
Picked tensor(0) edges
DataBatch(x=[42, 3], edge_index=[2, 456], y_nodes=[42], y_edges=[456], batch=[42], ptr=[2])
Model application took 0.04520392417907715 seconds
Picked tensor(0) edges
DataBatch(x=[50, 3], edge_index=[2, 652], y_nodes=[50], y_edges=[652], batch=[50], ptr=[2])
Model application took 0.07165789604187012 seconds
Picked tensor(0) edges
DataBatch(x=[15, 3], edge_index=[2, 63], y_nodes=[15], y_edges=[63], batch=[15], ptr=[2])
Model application took 0.033356666564941406 s

Model application took 0.1259143352508545 seconds
Picked tensor(0) edges
DataBatch(x=[113, 3], edge_index=[2, 3462], y_nodes=[113], y_edges=[3462], batch=[113], ptr=[2])
Model application took 0.09836030006408691 seconds
Picked tensor(0) edges
DataBatch(x=[149, 3], edge_index=[2, 5668], y_nodes=[149], y_edges=[5668], batch=[149], ptr=[2])
Model application took 0.137556791305542 seconds
Picked tensor(0) edges
DataBatch(x=[45, 3], edge_index=[2, 497], y_nodes=[45], y_edges=[497], batch=[45], ptr=[2])
Model application took 0.04375958442687988 seconds
Picked tensor(0) edges
DataBatch(x=[54, 3], edge_index=[2, 761], y_nodes=[54], y_edges=[761], batch=[54], ptr=[2])
Model application took 0.0599367618560791 seconds
Picked tensor(0) edges
DataBatch(x=[110, 3], edge_index=[2, 2995], y_nodes=[110], y_edges=[2995], batch=[110], ptr=[2])
Model application took 0.08783698081970215 seconds
Picked tensor(0) edges
DataBatch(x=[30, 3], edge_index=[2, 208], y_nodes=[30], y_edges=[208], batch=[30], pt

Model application took 0.1753220558166504 seconds
Picked tensor(0) edges
DataBatch(x=[191, 3], edge_index=[2, 9360], y_nodes=[191], y_edges=[9360], batch=[191], ptr=[2])
Model application took 0.16189360618591309 seconds
Picked tensor(0) edges
DataBatch(x=[140, 3], edge_index=[2, 4966], y_nodes=[140], y_edges=[4966], batch=[140], ptr=[2])
Model application took 0.10369062423706055 seconds
Picked tensor(0) edges
DataBatch(x=[27, 3], edge_index=[2, 229], y_nodes=[27], y_edges=[229], batch=[27], ptr=[2])
Model application took 0.035723209381103516 seconds
Picked tensor(0) edges
DataBatch(x=[39, 3], edge_index=[2, 418], y_nodes=[39], y_edges=[418], batch=[39], ptr=[2])
Model application took 0.04170393943786621 seconds
Picked tensor(0) edges
DataBatch(x=[107, 3], edge_index=[2, 2928], y_nodes=[107], y_edges=[2928], batch=[107], ptr=[2])
Model application took 0.08394503593444824 seconds
Picked tensor(0) edges
DataBatch(x=[194, 3], edge_index=[2, 9902], y_nodes=[194], y_edges=[9902], batch=

Model application took 0.1054072380065918 seconds
Picked tensor(0) edges
DataBatch(x=[36, 3], edge_index=[2, 316], y_nodes=[36], y_edges=[316], batch=[36], ptr=[2])
Model application took 0.04019570350646973 seconds
Picked tensor(0) edges
DataBatch(x=[139, 3], edge_index=[2, 4738], y_nodes=[139], y_edges=[4738], batch=[139], ptr=[2])
Model application took 0.11466407775878906 seconds
Picked tensor(0) edges
DataBatch(x=[112, 3], edge_index=[2, 3152], y_nodes=[112], y_edges=[3152], batch=[112], ptr=[2])
Model application took 0.102020263671875 seconds
Picked tensor(0) edges
DataBatch(x=[72, 3], edge_index=[2, 1346], y_nodes=[72], y_edges=[1346], batch=[72], ptr=[2])
Model application took 0.07469868659973145 seconds
Picked tensor(0) edges
DataBatch(x=[15, 3], edge_index=[2, 68], y_nodes=[15], y_edges=[68], batch=[15], ptr=[2])
Model application took 0.032784223556518555 seconds
Picked tensor(0) edges
DataBatch(x=[144, 3], edge_index=[2, 5298], y_nodes=[144], y_edges=[5298], batch=[144], 

Model application took 0.05406641960144043 seconds
Picked tensor(0) edges
DataBatch(x=[185, 3], edge_index=[2, 8662], y_nodes=[185], y_edges=[8662], batch=[185], ptr=[2])
Model application took 0.15699148178100586 seconds
Picked tensor(0) edges
DataBatch(x=[146, 3], edge_index=[2, 5562], y_nodes=[146], y_edges=[5562], batch=[146], ptr=[2])
Model application took 0.11878180503845215 seconds
Picked tensor(0) edges
DataBatch(x=[109, 3], edge_index=[2, 2730], y_nodes=[109], y_edges=[2730], batch=[109], ptr=[2])
Model application took 0.0989522933959961 seconds
Picked tensor(0) edges
DataBatch(x=[13, 3], edge_index=[2, 50], y_nodes=[13], y_edges=[50], batch=[13], ptr=[2])
Model application took 0.03634929656982422 seconds
Picked tensor(0) edges
DataBatch(x=[147, 3], edge_index=[2, 5721], y_nodes=[147], y_edges=[5721], batch=[147], ptr=[2])
Model application took 0.11386871337890625 seconds
Picked tensor(0) edges
DataBatch(x=[187, 3], edge_index=[2, 8811], y_nodes=[187], y_edges=[8811], batc

Model application took 0.11761331558227539 seconds
Picked tensor(0) edges
DataBatch(x=[77, 3], edge_index=[2, 1513], y_nodes=[77], y_edges=[1513], batch=[77], ptr=[2])
Model application took 0.07726716995239258 seconds
Picked tensor(0) edges
DataBatch(x=[12, 3], edge_index=[2, 40], y_nodes=[12], y_edges=[40], batch=[12], ptr=[2])
Model application took 0.029071807861328125 seconds
Picked tensor(0) edges
DataBatch(x=[33, 3], edge_index=[2, 294], y_nodes=[33], y_edges=[294], batch=[33], ptr=[2])
Model application took 0.040956974029541016 seconds
Picked tensor(0) edges
DataBatch(x=[75, 3], edge_index=[2, 1533], y_nodes=[75], y_edges=[1533], batch=[75], ptr=[2])
Model application took 0.07320427894592285 seconds
Picked tensor(0) edges
DataBatch(x=[102, 3], edge_index=[2, 2592], y_nodes=[102], y_edges=[2592], batch=[102], ptr=[2])
Model application took 0.10894894599914551 seconds
Picked tensor(0) edges
DataBatch(x=[144, 3], edge_index=[2, 5191], y_nodes=[144], y_edges=[5191], batch=[144],

Model application took 0.03777647018432617 seconds
Picked tensor(0) edges
DataBatch(x=[196, 3], edge_index=[2, 10116], y_nodes=[196], y_edges=[10116], batch=[196], ptr=[2])
Model application took 0.16561055183410645 seconds
Picked tensor(0) edges
DataBatch(x=[94, 3], edge_index=[2, 2294], y_nodes=[94], y_edges=[2294], batch=[94], ptr=[2])
Model application took 0.08409476280212402 seconds
Picked tensor(0) edges
DataBatch(x=[62, 3], edge_index=[2, 1036], y_nodes=[62], y_edges=[1036], batch=[62], ptr=[2])
Model application took 0.06471133232116699 seconds
Picked tensor(0) edges
DataBatch(x=[116, 3], edge_index=[2, 3380], y_nodes=[116], y_edges=[3380], batch=[116], ptr=[2])
Model application took 0.09303951263427734 seconds
Picked tensor(0) edges
DataBatch(x=[90, 3], edge_index=[2, 1919], y_nodes=[90], y_edges=[1919], batch=[90], ptr=[2])
Model application took 0.08045196533203125 seconds
Picked tensor(0) edges
DataBatch(x=[112, 3], edge_index=[2, 3003], y_nodes=[112], y_edges=[3003], bat

Model application took 0.10737919807434082 seconds
Picked tensor(0) edges
DataBatch(x=[158, 3], edge_index=[2, 5924], y_nodes=[158], y_edges=[5924], batch=[158], ptr=[2])
Model application took 0.12448644638061523 seconds
Picked tensor(0) edges
DataBatch(x=[120, 3], edge_index=[2, 3937], y_nodes=[120], y_edges=[3937], batch=[120], ptr=[2])
Model application took 0.09470081329345703 seconds
Picked tensor(0) edges
DataBatch(x=[82, 3], edge_index=[2, 1701], y_nodes=[82], y_edges=[1701], batch=[82], ptr=[2])
Model application took 0.0702812671661377 seconds
Picked tensor(0) edges
DataBatch(x=[11, 3], edge_index=[2, 36], y_nodes=[11], y_edges=[36], batch=[11], ptr=[2])
Model application took 0.028365612030029297 seconds
Picked tensor(0) edges
DataBatch(x=[143, 3], edge_index=[2, 5244], y_nodes=[143], y_edges=[5244], batch=[143], ptr=[2])
Model application took 0.11352992057800293 seconds
Picked tensor(0) edges
DataBatch(x=[129, 3], edge_index=[2, 4237], y_nodes=[129], y_edges=[4237], batch=

Model application took 0.06820559501647949 seconds
Picked tensor(0) edges
DataBatch(x=[39, 3], edge_index=[2, 423], y_nodes=[39], y_edges=[423], batch=[39], ptr=[2])
Model application took 0.048584699630737305 seconds
Picked tensor(0) edges
DataBatch(x=[22, 3], edge_index=[2, 129], y_nodes=[22], y_edges=[129], batch=[22], ptr=[2])
Model application took 0.036487579345703125 seconds
Picked tensor(0) edges
DataBatch(x=[174, 3], edge_index=[2, 7264], y_nodes=[174], y_edges=[7264], batch=[174], ptr=[2])
Model application took 0.15448355674743652 seconds
Picked tensor(0) edges
DataBatch(x=[61, 3], edge_index=[2, 906], y_nodes=[61], y_edges=[906], batch=[61], ptr=[2])
Model application took 0.0735015869140625 seconds
Picked tensor(0) edges
DataBatch(x=[144, 3], edge_index=[2, 5048], y_nodes=[144], y_edges=[5048], batch=[144], ptr=[2])
Model application took 0.10300469398498535 seconds
Picked tensor(0) edges
DataBatch(x=[154, 3], edge_index=[2, 6221], y_nodes=[154], y_edges=[6221], batch=[154

Model application took 0.17263221740722656 seconds
Picked tensor(0) edges
DataBatch(x=[17, 3], edge_index=[2, 73], y_nodes=[17], y_edges=[73], batch=[17], ptr=[2])
Model application took 0.03606247901916504 seconds
Picked tensor(0) edges
DataBatch(x=[87, 3], edge_index=[2, 2004], y_nodes=[87], y_edges=[2004], batch=[87], ptr=[2])
Model application took 0.08427286148071289 seconds
Picked tensor(0) edges
DataBatch(x=[37, 3], edge_index=[2, 354], y_nodes=[37], y_edges=[354], batch=[37], ptr=[2])
Model application took 0.04058670997619629 seconds
Picked tensor(0) edges
DataBatch(x=[152, 3], edge_index=[2, 5953], y_nodes=[152], y_edges=[5953], batch=[152], ptr=[2])
Model application took 0.10781002044677734 seconds
Picked tensor(0) edges
DataBatch(x=[106, 3], edge_index=[2, 3240], y_nodes=[106], y_edges=[3240], batch=[106], ptr=[2])
Model application took 0.09854984283447266 seconds
Picked tensor(0) edges
DataBatch(x=[80, 3], edge_index=[2, 1583], y_nodes=[80], y_edges=[1583], batch=[80], p

Model application took 0.16105365753173828 seconds
Picked tensor(0) edges
DataBatch(x=[37, 3], edge_index=[2, 363], y_nodes=[37], y_edges=[363], batch=[37], ptr=[2])
Model application took 0.04307103157043457 seconds
Picked tensor(0) edges
DataBatch(x=[67, 3], edge_index=[2, 1153], y_nodes=[67], y_edges=[1153], batch=[67], ptr=[2])
Model application took 0.0777273178100586 seconds
Picked tensor(0) edges
DataBatch(x=[64, 3], edge_index=[2, 1184], y_nodes=[64], y_edges=[1184], batch=[64], ptr=[2])
Model application took 0.06991076469421387 seconds
Picked tensor(0) edges
DataBatch(x=[17, 3], edge_index=[2, 68], y_nodes=[17], y_edges=[68], batch=[17], ptr=[2])
Model application took 0.03610348701477051 seconds
Picked tensor(0) edges


In [9]:
# Take the first game of the (unshuffled) test split.
it = iter(test_loader)
game = next(it)

# Show the raw model output for this game: the (node, edge) pair of class scores.
model_output = model(game.x, game.edge_index)
print(model_output)

def get_index_of_winning_edge(tensor):
    """Return the index of the first element of ``tensor`` equal to 1.

    Falls back to 0 when no element equals 1. That result is indistinguishable
    from a hit at index 0, so callers that need to tell the two apart must check
    ``tensor[0]`` themselves.
    """
    # Convert once instead of calling .item() on every element.
    for index, value in enumerate(tensor.tolist()):
        if value == 1:
            return index
    return 0

print(f"parity {len(game.x)};")

# For every vertex keep the first outgoing edge with the highest ground-truth
# label game.y_edges, i.e. an edge of pgsolver's winning strategy when one exists.
winning_edge_index = [None] * len(game.x)

for edge_index in range(len(game.y_edges)):
    start_node = game.edge_index[0][edge_index].item()
    current_winning_edge_index = winning_edge_index[start_node]

    if current_winning_edge_index is None or (
        game.y_edges[edge_index].item() > game.y_edges[current_winning_edge_index].item()
    ):
        winning_edge_index[start_node] = edge_index

for start_node, edge_index in enumerate(winning_edge_index):
    if edge_index is None:
        continue  # vertex has no outgoing edge in game.edge_index
    owner = 0 if game.x[start_node][1].item() else 1
    next_node = game.edge_index[1][edge_index].item()
    print(f"{start_node} {owner} {next_node};")

parity 142;
0 1 2;
1 1 60;
2 1 2;
3 1 19;
4 0 15;
5 0 12;
6 0 50;
7 0 7;
8 0 8;
9 1 22;
10 1 22;
11 0 7;
12 0 26;
13 1 9;
14 1 19;
15 0 15;
16 0 18;
17 0 7;
18 0 7;
19 1 2;
20 0 17;
21 0 4;
22 1 19;
23 1 22;
24 1 2;
25 1 9;
26 0 26;
27 1 28;
28 1 2;
29 1 71;
30 1 2;
31 1 28;
32 1 2;
33 0 8;
34 0 26;
35 0 8;
36 0 48;
37 1 0;
38 0 7;
39 1 28;
40 0 8;
41 1 0;
42 1 24;
43 1 29;
44 0 7;
45 1 3;
46 0 8;
47 1 23;
48 0 8;
49 1 2;
50 0 35;
51 0 11;
52 1 52;
53 1 24;
54 0 8;
55 0 8;
56 0 7;
57 1 57;
58 0 7;
59 0 7;
60 1 2;
61 0 61;
62 0 62;
63 1 63;
64 1 37;
65 1 13;
66 1 0;
67 1 23;
68 0 7;
69 0 15;
70 0 38;
71 1 71;
72 1 2;
73 0 7;
74 1 3;
75 0 7;
76 0 8;
77 1 2;
78 1 78;
79 1 79;
80 1 53;
81 0 11;
82 1 52;
83 1 0;
84 1 0;
85 0 11;
86 1 86;
87 0 7;
88 0 4;
89 0 7;
90 1 0;
91 1 28;
92 1 0;
93 0 50;
94 0 7;
95 1 0;
96 1 0;
97 0 8;
98 1 2;
99 0 38;
100 1 0;
101 0 15;
102 0 11;
103 1 2;
104 0 104;
105 1 3;
106 0 70;
107 0 107;
108 0 8;
109 1 2;
110 0 17;
111 1 111;
112 1 31;
113 0 4;
114 1 2;
115 

## Model application example

In [11]:
# Use the second game of the (unshuffled) test split as the running example.
it = iter(test_loader)
next(it)  # skip the first game
example_game = next(it)

out_nodes, out_edges = model(example_game.x, example_game.edge_index)
# Hard predictions: the highest-scoring class for every node and every edge.
pred_nodes, pred_edges = out_nodes.argmax(dim=1), out_edges.argmax(dim=1)

### Output 1: Winning regions of players 0 and 1

**Predicted regions**

In [12]:
# Predicted class (argmax of the node classifier output) for every vertex.
pred_nodes

tensor([1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0,
        0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0,
        0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0,
        1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0,
        1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0,
        0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1,
        1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0,
        1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1,
        1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0,
        1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1,
        1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1,
        0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0,
        1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0,

**Actual winning regions (calculated by pgsolver)**

In [13]:
# Ground-truth vertex labels stored in the dataset for the same game.
example_game.y_nodes

tensor([1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0,
        0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0,
        0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0,
        1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0,
        1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0,
        0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1,
        1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0,
        1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1,
        1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0,
        1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1,
        1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1,
        0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0,
        1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0,

### Output 2: Winning strategies

A **1** means that the edge belongs to a winning strategy

**Winning strategy from pgsolver** (ground truth, `y_edges`)

In [14]:
# (source, target) pairs of the edges labelled 1 in the ground truth (y_edges).
example_game.edge_index[:,example_game.y_edges == 1]

tensor([[  0,   1,   2,  ..., 664, 665, 666],
        [ 40,   1,  26,  ..., 613, 605, 613]])

**Predicted winning strategy** (model output, `pred_edges`)

In [78]:
# (source, target) pairs of the edges the model predicts as class 1 (pred_edges).
example_game.edge_index[:,pred_edges == 1]

tensor([[  9,   9,  20, 101, 117, 117, 138, 138, 138, 138, 292, 292, 292],
        [ 12,  16,  16, 150,  51, 102,  71, 125, 153, 178, 250, 309, 334]])

In [72]:
# Largest predicted edge class; 1 means at least one edge was predicted as class 1.
max(pred_edges)

tensor(1)

In [43]:
# Total number of edges (one label per edge) in the example game.
example_game.y_edges.shape[0]

17852