In [None]:
import json

import networkx as nx
%matplotlib inline
import matplotlib.pyplot as plt

import parser

In [None]:
# TODO: account for prev/next
import collections

class BoardStateNode:
    """
    A board-state node identified by ``key`` and displayed by ``label``.

    Equality and hashing are both based on ``key``, so two instances built for
    the same key collapse into a single entry in sets and dict keys.
    """

    def __init__(self, idx, label):
        # ``idx`` is the node's identity; ``label`` is only used for display.
        self.key = idx
        self.label = label

    def __repr__(self):
        return self.label

    def __eq__(self, other):
        # Must agree with __hash__: objects that compare equal must hash equal.
        # Returning NotImplemented lets Python fall back for foreign types.
        if not isinstance(other, BoardStateNode):
            return NotImplemented
        return self.key == other.key

    def __hash__(self):
        return hash(self.key)
    
    
class ResetNode:
    """
    A node representing a board reset, identified by ``key``.

    The label is always ``"RESET"``. Equality and hashing are both based on
    ``key``, so instances with the same key collapse in sets and dict keys.
    """

    def __init__(self, idx):
        self.key = idx
        self.label = "RESET"

    def __repr__(self):
        return self.label

    def __eq__(self, other):
        # Must agree with __hash__: objects that compare equal must hash equal.
        if not isinstance(other, ResetNode):
            return NotImplemented
        return self.key == other.key

    def __hash__(self):
        return hash(self.key)


def get_state_graphs(level):
    """
    Get all the state graphs for a given playthrough of a given level.

    Walks ``level.actions`` in order. Each ``"state-path-save"`` action yields
    one ``nx.DiGraph`` built from the JSON stored in its ``action_detail``. A
    ``"victory"`` action sets the ``victory`` attribute on the most recently
    built graph.

    Graph attributes:
        victory: True if a "victory" action occurred while this was the most
            recently saved graph.
        reset: True if the saved path contains a "reset" node.

    Nodes are keyed either by the string ``"reset"`` or by ``str`` of the
    sorted board tuple. Board nodes carry their raw node data under the
    ``node_data`` attribute.

    Args:
        level: object with an ``actions`` sequence. Each action is assumed to be
            a mapping whose ``"1"`` entry holds ``action_id`` (and, for saves,
            ``action_detail``). This shape is taken from usage here; TODO
            confirm against the parser.

    Returns:
        list of nx.DiGraph, in save order. A "victory" seen before any save has
        no graph to mark and is ignored.
    """
    graphs = []
    for action in level.actions:
        payload = action["1"]
        action_id = payload["action_id"]

        if action_id == "victory":
            # Guard: a victory logged before any save has nothing to mark
            # (indexing graphs[-1] on an empty list would raise IndexError).
            if graphs:
                graphs[-1].graph["victory"] = True
            continue

        if action_id != "state-path-save":
            continue

        graph_detail = json.loads(payload["action_detail"])
        graph = nx.DiGraph(victory=False, reset=False)
        # node_keys[i] is the graph key of graph_detail["nodes"][i]; edges
        # reference nodes by that position.
        node_keys = []
        for node in graph_detail["nodes"]:
            if node.get("data") == "reset":
                node_keys.append("reset")
                graph.graph["reset"] = True
                graph.add_node("reset")
            else:
                # Sorting makes the key independent of board element order.
                key = str(tuple(sorted(node["data"]["board"])))
                node_keys.append(key)
                graph.add_node(key, node_data=node["data"])

        for edge in graph_detail["edges"]:
            graph.add_edge(node_keys[edge["from"]], node_keys[edge["to"]])

        graphs.append(graph)
    return graphs


def get_complete_state_graphs(level_sequence, level_id):
    """
    Collect the state graphs of every playthrough whose ``id`` equals ``level_id``.

    Graphs are ordered first by the position of their level in
    ``level_sequence``, then by the order ``get_state_graphs`` returns them.
    """
    matching_levels = (lvl for lvl in level_sequence if lvl.id == level_id)
    return [state_graph
            for lvl in matching_levels
            for state_graph in get_state_graphs(lvl)]


def only_complete_graphs(graphs):
    """
    Keep only the graphs that ended in a reset or a victory.

    A graph is kept when its ``reset`` graph attribute is truthy or, failing
    that, its ``victory`` graph attribute is truthy. Input order is preserved.
    """
    def _ended_in_reset_or_victory(candidate):
        attrs = candidate.graph
        return attrs["reset"] or attrs["victory"]

    return list(filter(_ended_in_reset_or_victory, graphs))


def draw_graph(graph, size=(20, 20)):
    """
    Draw ``graph`` with node labels on a new matplotlib figure.

    Args:
        graph: the networkx graph to draw.
        size: figure size in inches, as ``(width, height)``.
    """
    # Create a fresh figure on every call. Re-using a fixed figure number would
    # re-activate an existing figure (ignoring ``figsize``) and draw on top of
    # whatever it already contains.
    _, ax = plt.subplots(figsize=size)
    nx.draw_networkx(
        graph,
        with_labels=True,
        ax=ax,
    )
    
def merge_graphs(graphs):
    """
    Union several state graphs into a single ``nx.DiGraph``.

    Nodes are matched by key (a board-state string or ``"reset"``), so a state
    appearing in several graphs becomes one node. The edges of all graphs are
    combined. Only structure is kept: node attributes (such as ``node_data``)
    and graph attributes (``victory``/``reset``) are not copied.

    Nodes are inserted in first-seen order across ``graphs``, before any edges.

    Args:
        graphs: iterable of networkx graphs.

    Returns:
        nx.DiGraph containing the union of all nodes and edges.
    """
    # Materialize so a one-shot iterator can be walked twice below.
    graphs = list(graphs)

    merged = nx.DiGraph()
    # DiGraph already de-duplicates by node key, so no manual mapping is needed.
    # Iterating a NodeView yields bare keys, so node attributes are not carried over.
    for source in graphs:
        merged.add_nodes_from(source.nodes)
    for source in graphs:
        merged.add_edges_from(source.edges)
    return merged

In [None]:
# Identifier passed to parser.read_events -- presumably a participant/log id; confirm.
PARTICIPANT_ID = "p2"
events, level_sequence = parser.read_events(PARTICIPANT_ID)

In [None]:
# Position in level_sequence (a list index, not a level id).
LEVEL_INDEX = 75
level_state_graphs = get_state_graphs(level_sequence[LEVEL_INDEX])
graphs = only_complete_graphs(level_state_graphs)

In [None]:
# Plot the first complete (reset- or victory-terminated) graph of that playthrough.
first_complete_graph = graphs[0]
draw_graph(first_complete_graph)

In [None]:
# Every occurrence of level id 65 in the sequence, paired with its position.
# NOTE(review): the merged-graph cell below uses level id 56 -- confirm 65 is intended.
x = [(pos, lvl) for pos, lvl in enumerate(level_sequence) if lvl.id == 65]
# Number of recorded actions for each of those occurrences.
[(pos, len(lvl.actions)) for pos, lvl in x]

In [None]:
# Merge every complete graph from all playthroughs of one level id and plot it.
MERGE_LEVEL_ID = 56
all_playthrough_graphs = get_complete_state_graphs(level_sequence, MERGE_LEVEL_ID)
merged_graph = merge_graphs(only_complete_graphs(all_playthrough_graphs))
draw_graph(merged_graph, size=(25, 25))