# Training a GNN to solve genus 2 Minesweeper

Solving genus 2 Minesweeper is a much simpler task than solving genus g Minesweeper. In genus g Minesweeper, a tile can have any number of neighbors, which makes it hard to represent the number of adjacent mines with a fixed-size encoding. In genus 2, this number is at most 23. That is still somewhat difficult to cope with, but it is much more manageable.

This notebook serves as an initial test of the results I can get on genus 2 Minesweeper. Later on, I will need to figure out how to extend this approach to genus g.

# Data Loading 

This section defines functions that load data for use by the neural network.

In [1]:
from train_generator import *

In [6]:
# Need to set the dgl backend
import os

os.environ["DGLBACKEND"] = "mxnet"
import mxnet as mx
import dgl

# Create a dgl graph from a single piece of generator output (single (X, Y) pair)
def g_from_gen(gen_out):
    """Build a DGL graph from one ``((edges, tile_data), y)`` generator sample.

    Args:
        gen_out: A pair ``((edges, tile_data), y)``. ``edges`` is an iterable of
            ``(src, dst)`` node-index pairs. ``tile_data[i][0]`` is the state index
            and ``tile_data[i][1]`` the mine-count index of node ``i``. ``y`` has
            one truthy/falsy "has mine" entry per node.

    Returns:
        A ``dgl`` graph with ``len(y)`` nodes and a self-loop added to every node.
        It carries these node features:

        - ``"state"``: one-hot encoding of width 3.
        - ``"mines"``: one-hot encoding of width 23.
        - ``"y"``: 0/1 label of shape ``(num_nodes, 1)``.
    """
    # Split data
    (edges, tile_data), y = gen_out
    edges = list(edges)
    num_nodes = len(y)

    # Process data into mxnet format
    from_node = mx.nd.array([edge[0] for edge in edges], dtype="int32")
    to_node = mx.nd.array([edge[1] for edge in edges], dtype="int32")
    # Integer dtype: these arrays are used as column indices for the one-hot encodings.
    tile_states = mx.np.array(
        [tile_data[i][0] for i in range(num_nodes)], dtype="int32"
    )
    tile_mines = mx.np.array(
        [tile_data[i][1] for i in range(num_nodes)], dtype="int32"
    )
    tile_y = mx.np.array([[1 if hasmine else 0] for hasmine in y])

    # Generate graph, adding a self-loop on every node (so no node has zero in-degree)
    g = dgl.graph((from_node, to_node), num_nodes=num_nodes)
    g = dgl.add_self_loop(g)

    # One-hot encode tile state and mine count, one row per node.
    rows = mx.np.arange(num_nodes, dtype="int32")
    ndata_state = mx.np.zeros((num_nodes, 3))
    # NOTE(review): width 23 encodes indices 0..22; confirm against the max mine count.
    ndata_mines = mx.np.zeros((num_nodes, 23))
    ndata_state[rows, tile_states] = 1
    # Mine counts go into their own matrix (previously written into ndata_state by mistake).
    ndata_mines[rows, tile_mines] = 1

    g.ndata["state"] = ndata_state
    g.ndata["mines"] = ndata_mines
    g.ndata["y"] = tile_y

    return g

# Smoke test: build one graph from the training generator's output.
# Import time explicitly instead of relying on `from train_generator import *`.
import time

start = time.time()
test_gen = train_generator()
# Take element [0] of the first generator output; presumably one ((edges, tile_data), y) sample.
test_out = next(test_gen)[0]
# Left as a bare expression so the notebook displays the resulting graph.
g_from_gen(test_out)

Graph(num_nodes=800, num_edges=1696,
      ndata_schemes={'state': Scheme(shape=(3,), dtype=dtype('float32')), 'mines': Scheme(shape=(23,), dtype=dtype('float32')), 'y': Scheme(shape=(1,), dtype=dtype('float32'))}
      edata_schemes={})

# Model definition

Definitions of the GNN models to test.

In [None]:
from dgl.nn import GraphConv
from mxnet import gluon, init, npx, autograd
from mxnet.gluon import nn

class GCN_1(nn.Block):
    """Graph-convolutional network producing one sigmoid output per node.

    The two node feature groups are embedded and merged, then passed through
    graph convolutions:

    - Tile state (width 3) and mine-count (width 23) features are embedded
      separately with 16-unit dense layers.
    - The two embeddings are concatenated and projected to 64 units.
    - Three ReLU GraphConv layers follow (64 -> 128 -> 256 -> 512).
    - A final GraphConv maps to 1 unit, followed by a sigmoid.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Separate embeddings for the two node feature groups
        self.dense_nstate = nn.Dense(16, in_units=3)
        self.dense_nmines = nn.Dense(16, in_units=23)
        # in_units left unspecified: inferred on first call (16 + 16 concatenated)
        self.dense_nmerge = nn.Dense(64)
        self.conv1 = GraphConv(64, 128)
        self.conv2 = GraphConv(128, 256)
        self.conv3 = GraphConv(256, 512)
        self.conv4 = GraphConv(512, 1)
        self.relu = nn.Activation("relu")
        self.sig = nn.Activation("sigmoid")

    def forward(self, g, in_nstates, in_nmines):
        """Return per-node outputs in (0, 1) with shape ``(num_nodes, 1)``.

        Args:
            g: The DGL graph the node features belong to.
            in_nstates: Node state features with last dimension 3.
            in_nmines: Node mine-count features with last dimension 23.
        """
        h_nstates = self.relu(self.dense_nstate(in_nstates))
        h_nmines = self.relu(self.dense_nmines(in_nmines))
        h_merged = mx.np.concatenate((h_nstates, h_nmines), axis=1)
        h_merged = self.relu(self.dense_nmerge(h_merged))

        # DGL GraphConv layers take the graph as their first argument.
        h_graph = h_merged
        for layer in (self.conv1, self.conv2, self.conv3):
            h_graph = self.relu(layer(g, h_graph))

        y_hat = self.sig(self.conv4(g, h_graph))
        return y_hat
        
    