In [2]:
from modules.dataloader import load_data

# 3000 boards are split 0.9 / 0.1 into training and test sets (see the printed counts).
boards_train, labels_train, boards_test, labels_test = load_data("../hex_data.csv", 0.9)

# Board side length, derived from the data rather than hard-coded (e.g. 5 for the
# 25-cell boards in hex_data.csv), so a CSV with another board size cannot silently
# produce features that only cover part of each board.
board_cell_count = len(boards_train[0])
dimensions = round(board_cell_count ** 0.5)
assert dimensions * dimensions == board_cell_count, (
    f"boards have {board_cell_count} cells, which is not a square board"
)
assert all(len(board) == board_cell_count for board in boards_train), "training boards differ in size"
assert all(len(board) == board_cell_count for board in boards_test), "test boards differ in size"

Number of boards: 3000
First board: OXXEOOOXEXXOXXEOOXEOEXXOO
First label: 0
Training samples:  2700
Test samples:  300
First training board:  OXXEOOOXEXXOXXEOOXEOEXXOO
First training label:  0
First test board:  XOXOOXOOXXXOEXXOOOXXOOOEX
First test label:  1


In [6]:
# setting up training graphs

from GraphTsetlinMachine.graphs import Graphs

# Node-property vocabulary handed to Graphs: cell state, one Row:i / Column:i pair
# per board index, then edge flags, position buckets and piece-count buckets.
symbols = [
    "X", "O", "E",
    *(label for index in range(dimensions) for label in (f"Row:{index}", f"Column:{index}")),
    "TopEdge", "BottomEdge", "LeftEdge", "RightEdge",
    "NearTop", "MidTop", "FarTop",
    "NearLeft", "MidLeft", "FarLeft",
    "RowXCountHigh", "RowXCountMed", "RowXCountLow",
    "RowOCountHigh", "RowOCountMed", "RowOCountLow",
    "ColumnXCountHigh", "ColumnXCountMed", "ColumnXCountLow",
    "ColumnOCountHigh", "ColumnOCountMed", "ColumnOCountLow",
]


# One graph per training board, all using the symbol vocabulary defined above.
number_of_train_graphs = len(boards_train)
graphs_train = Graphs(number_of_graphs=number_of_train_graphs, symbols=symbols)

# One node per board cell: dimensions ** 2 nodes (e.g. 25 for a 5x5 board).
number_of_nodes = dimensions * dimensions

# Every graph has the same node count.
for train_graph_id in range(number_of_train_graphs):
    graphs_train.set_number_of_graph_nodes(train_graph_id, number_of_nodes)


#function to get neighbors of a node (cell)
def get_neighbors(row, column, dimensions):
    """Return the on-board hex neighbours of cell (row, column).

    A cell touches up to six others. Candidates are tried in a fixed order
    (north-west, south-east, west, east, north-east, south-west), and any that
    fall outside the dimensions x dimensions board are dropped. Edge and corner
    cells therefore get fewer neighbours.

    Args:
        row: Row index of the cell.
        column: Column index of the cell.
        dimensions: Side length of the square board.

    Returns:
        A list of (row, column) tuples, in the candidate order above.
    """
    hex_offsets = ((-1, 0), (1, 0), (0, -1), (0, 1), (-1, 1), (1, -1))
    return [
        (row + d_row, column + d_col)
        for d_row, d_col in hex_offsets
        if 0 <= row + d_row < dimensions and 0 <= column + d_col < dimensions
    ]


# The hex adjacency depends only on the board size, never on the board contents,
# so each node's neighbour-id list (row-major ids) is built once and reused for
# every training graph instead of being recomputed twice per board.
node_neighbor_ids = []
for node_id in range(number_of_nodes):
    row, column = divmod(node_id, dimensions)
    node_neighbor_ids.append([
        neighbor_row * dimensions + neighbor_column
        for neighbor_row, neighbor_column in get_neighbors(row, column, dimensions)
    ])

graphs_train.prepare_node_configuration()

# Declare every node with its edge count; this equals the number of edges added below.
for graph_id in range(len(boards_train)):
    for node_id, neighbor_ids in enumerate(node_neighbor_ids):
        graphs_train.add_graph_node(graph_id, node_id, len(neighbor_ids))

graphs_train.prepare_edge_configuration()

# One "adjacent_cell" edge from each node to each of its neighbours.
for graph_id in range(len(boards_train)):
    for node_id, neighbor_ids in enumerate(node_neighbor_ids):
        for neighbor_id in neighbor_ids:
            graphs_train.add_graph_node_edge(graph_id, node_id, neighbor_id, "adjacent_cell")

# Cut points splitting [0, dimensions] into thirds. The same two values serve the
# vertical/horizontal position buckets and the piece-count buckets.
top_third = left_third = count_one_third = dimensions / 3
two_thirds = two_thirds_col = count_two_thirds = 2 * dimensions / 3

# Per-board piece counts: row_x_count[g][r] is the number of X pieces in row r of
# training board g; likewise for O pieces and for columns.
row_x_count, row_o_count, col_x_count, col_o_count = [], [], [], []
for board in boards_train:
    rows_x, rows_o = [0] * dimensions, [0] * dimensions
    cols_x, cols_o = [0] * dimensions, [0] * dimensions
    for cell_index in range(number_of_nodes):
        r, c = divmod(cell_index, dimensions)
        piece = board[cell_index]
        if piece == 'X':
            rows_x[r] += 1
            cols_x[c] += 1
        elif piece == 'O':
            rows_o[r] += 1
            cols_o[c] += 1
    row_x_count.append(rows_x)
    row_o_count.append(rows_o)
    col_x_count.append(cols_x)
    col_o_count.append(cols_o)


# Attach each node's property symbols, in this order: cell state, row/column
# index, board-edge flags, a vertical and a horizontal position bucket, then
# Low/Med/High buckets for the X and O piece counts in its row and its column.
for graph_id, board in enumerate(boards_train):
    for node_id in range(number_of_nodes):
        row, column = divmod(node_id, dimensions)

        # X means cell is occupied by player 0, O is player 1, and E is empty.
        # Any other character gets no state property.
        cell_value = board[node_id]
        if cell_value == 'X':
            graphs_train.add_graph_node_property(graph_id, node_id, "X")
        elif cell_value == 'O':
            graphs_train.add_graph_node_property(graph_id, node_id, "O")
        elif cell_value == 'E':
            graphs_train.add_graph_node_property(graph_id, node_id, "E")

        graphs_train.add_graph_node_property(graph_id, node_id, f"Row:{row}")
        graphs_train.add_graph_node_property(graph_id, node_id, f"Column:{column}")

        # Board-edge flags; a corner cell receives two of them.
        edge_flags = (
            (row == 0, "TopEdge"),
            (row == dimensions - 1, "BottomEdge"),
            (column == 0, "LeftEdge"),
            (column == dimensions - 1, "RightEdge"),
        )
        for on_edge, edge_symbol in edge_flags:
            if on_edge:
                graphs_train.add_graph_node_property(graph_id, node_id, edge_symbol)

        # Three-way buckets as (value, lower cut, upper cut, (low, mid, high) symbols):
        # value < lower -> low, value < upper -> mid, otherwise high.
        buckets = (
            (row, top_third, two_thirds,
             ("NearTop", "MidTop", "FarTop")),
            (column, left_third, two_thirds_col,
             ("NearLeft", "MidLeft", "FarLeft")),
            (row_x_count[graph_id][row], count_one_third, count_two_thirds,
             ("RowXCountLow", "RowXCountMed", "RowXCountHigh")),
            (row_o_count[graph_id][row], count_one_third, count_two_thirds,
             ("RowOCountLow", "RowOCountMed", "RowOCountHigh")),
            (col_x_count[graph_id][column], count_one_third, count_two_thirds,
             ("ColumnXCountLow", "ColumnXCountMed", "ColumnXCountHigh")),
            (col_o_count[graph_id][column], count_one_third, count_two_thirds,
             ("ColumnOCountLow", "ColumnOCountMed", "ColumnOCountHigh")),
        )
        for value, lower_cut, upper_cut, (low_symbol, mid_symbol, high_symbol) in buckets:
            if value < lower_cut:
                bucket_symbol = low_symbol
            elif value < upper_cut:
                bucket_symbol = mid_symbol
            else:
                bucket_symbol = high_symbol
            graphs_train.add_graph_node_property(graph_id, node_id, bucket_symbol)

graphs_train.encode()

In [4]:
# Test graphs: built the same way as the training graphs, but as a separate Graphs
# object. No symbol list is passed here; init_with=graphs_train supplies the
# vocabulary (and presumably its encoding — confirm in the GraphTsetlinMachine docs).
number_of_test_graphs = len(boards_test)
graphs_test = Graphs(number_of_graphs=number_of_test_graphs, init_with=graphs_train)

# Every test graph has one node per board cell.
for test_graph_id in range(number_of_test_graphs):
    graphs_test.set_number_of_graph_nodes(test_graph_id, number_of_nodes)

# The hex adjacency depends only on the board size, never on the board contents,
# so each node's neighbour-id list (row-major ids) is built once and reused for
# every test graph instead of being recomputed twice per board.
node_neighbor_ids = []
for node_id in range(number_of_nodes):
    row, column = divmod(node_id, dimensions)
    node_neighbor_ids.append([
        neighbor_row * dimensions + neighbor_column
        for neighbor_row, neighbor_column in get_neighbors(row, column, dimensions)
    ])

graphs_test.prepare_node_configuration()

# Declare every node with its edge count; this equals the number of edges added below.
for graph_id in range(len(boards_test)):
    for node_id, neighbor_ids in enumerate(node_neighbor_ids):
        graphs_test.add_graph_node(graph_id, node_id, len(neighbor_ids))

graphs_test.prepare_edge_configuration()

# One "adjacent_cell" edge from each node to each of its neighbours.
for graph_id in range(len(boards_test)):
    for node_id, neighbor_ids in enumerate(node_neighbor_ids):
        for neighbor_id in neighbor_ids:
            graphs_test.add_graph_node_edge(graph_id, node_id, neighbor_id, "adjacent_cell")

# Per-board piece counts for the test set: row_x_count[g][r] is the number of X
# pieces in row r of test board g; likewise for O pieces and for columns.
row_x_count = [
    [sum(1 for c in range(dimensions) if board[r * dimensions + c] == 'X') for r in range(dimensions)]
    for board in boards_test
]
row_o_count = [
    [sum(1 for c in range(dimensions) if board[r * dimensions + c] == 'O') for r in range(dimensions)]
    for board in boards_test
]
col_x_count = [
    [sum(1 for r in range(dimensions) if board[r * dimensions + c] == 'X') for c in range(dimensions)]
    for board in boards_test
]
col_o_count = [
    [sum(1 for r in range(dimensions) if board[r * dimensions + c] == 'O') for c in range(dimensions)]
    for board in boards_test
]

# Attach each test node's property symbols, in the same order as for training:
# cell state, row/column index, board-edge flags, a vertical and a horizontal
# position bucket, then Low/Med/High buckets for the X and O piece counts in its
# row and its column.
for graph_id, board in enumerate(boards_test):
    for node_id in range(number_of_nodes):
        row, column = divmod(node_id, dimensions)

        # Any character other than X / O / E gets no state property.
        cell_value = board[node_id]
        if cell_value == 'X':
            graphs_test.add_graph_node_property(graph_id, node_id, "X")
        elif cell_value == 'O':
            graphs_test.add_graph_node_property(graph_id, node_id, "O")
        elif cell_value == 'E':
            graphs_test.add_graph_node_property(graph_id, node_id, "E")

        graphs_test.add_graph_node_property(graph_id, node_id, f"Row:{row}")
        graphs_test.add_graph_node_property(graph_id, node_id, f"Column:{column}")

        # Board-edge flags; a corner cell receives two of them.
        edge_flags = (
            (row == 0, "TopEdge"),
            (row == dimensions - 1, "BottomEdge"),
            (column == 0, "LeftEdge"),
            (column == dimensions - 1, "RightEdge"),
        )
        for on_edge, edge_symbol in edge_flags:
            if on_edge:
                graphs_test.add_graph_node_property(graph_id, node_id, edge_symbol)

        # Three-way buckets as (value, lower cut, upper cut, (low, mid, high) symbols):
        # value < lower -> low, value < upper -> mid, otherwise high.
        buckets = (
            (row, top_third, two_thirds,
             ("NearTop", "MidTop", "FarTop")),
            (column, left_third, two_thirds_col,
             ("NearLeft", "MidLeft", "FarLeft")),
            (row_x_count[graph_id][row], count_one_third, count_two_thirds,
             ("RowXCountLow", "RowXCountMed", "RowXCountHigh")),
            (row_o_count[graph_id][row], count_one_third, count_two_thirds,
             ("RowOCountLow", "RowOCountMed", "RowOCountHigh")),
            (col_x_count[graph_id][column], count_one_third, count_two_thirds,
             ("ColumnXCountLow", "ColumnXCountMed", "ColumnXCountHigh")),
            (col_o_count[graph_id][column], count_one_third, count_two_thirds,
             ("ColumnOCountLow", "ColumnOCountMed", "ColumnOCountHigh")),
        )
        for value, lower_cut, upper_cut, (low_symbol, mid_symbol, high_symbol) in buckets:
            if value < lower_cut:
                bucket_symbol = low_symbol
            elif value < upper_cut:
                bucket_symbol = mid_symbol
            else:
                bucket_symbol = high_symbol
            graphs_test.add_graph_node_property(graph_id, node_id, bucket_symbol)

graphs_test.encode()

In [5]:
#training logic
from modules.trainer import train

# Graph Tsetlin machine hyperparameters, forwarded unchanged to train().
# "accuracy_threshhold" is spelled exactly as train()'s keyword expects.
tm_hyperparameters = {
    "number_of_clauses": 500,
    "T": 5000,
    "s": 1.0,
    "depth": 2,
    "accuracy_threshhold": 1,
    "epochs": 50,
}

train(
    graphs_train=graphs_train,
    labels_train=labels_train,
    graphs_test=graphs_test,
    labels_test=labels_test,
    dimensions=dimensions,
    **tm_hyperparameters,
)

Initialization of sparse structure.


FigureWidget({
    'data': [{'mode': 'lines',
              'name': 'Training Accuracy',
              'type': 'scatter',
              'uid': '51e6f60b-b85c-4f1c-ae7c-8c7dcacb1286',
              'x': [],
              'y': []},
             {'mode': 'lines',
              'name': 'Testing Accuracy',
              'type': 'scatter',
              'uid': 'bfd38aeb-3fa1-4e76-a3b1-234c6769fc41',
              'x': [],
              'y': []},
             {'line': {'color': 'gray', 'dash': 'dash'},
              'mode': 'lines',
              'name': 'Test Regression',
              'type': 'scatter',
              'uid': '3d2e5f4a-9f46-458e-b325-0c662de4ad51',
              'x': [],
              'y': []}],
    'layout': {'annotations': [{'font': {'size': 18},
                                'showarrow': False,
                                'text': 'Train: - , Test: -',
                                'x': 0.5,
                                'xref': 'paper',
                          