In [None]:
import json

import networkx as nx
%matplotlib inline
import matplotlib.pyplot as plt

import parser

In [None]:
# TODO: account for prev/next
import collections

class BoardStateNode:
    """A board-state node identified by ``key`` and displayed as ``label``.

    Hashing delegates to ``key``; equality is the default (object identity).
    """

    def __init__(self, idx, label):
        self.key, self.label = idx, label

    def __hash__(self):
        return hash(self.key)

    def __repr__(self):
        # ``label`` is returned verbatim, so it is expected to be a str.
        return self.label
    
    
class ResetNode:
    """A node marking a reset; identified by ``key``, always shown as ``"RESET"``.

    Hashing delegates to ``key``; equality is the default (object identity).
    """

    def __init__(self, idx):
        self.key, self.label = idx, "RESET"

    def __hash__(self):
        return hash(self.key)

    def __repr__(self):
        return self.label


def get_state_graphs(level):
    """
    Get all the state graphs for a given playthrough of a given level.

    Each ``"state-path-save"`` action in ``level.actions`` yields one
    ``nx.DiGraph`` built from the JSON stored in its ``action_detail``.
    A ``"victory"`` action flags the most recently built graph.

    Graph-level attributes on each returned graph:
        victory: True if a "victory" action followed this graph.
        reset:   True if the saved path contains a reset node.

    Parameters
    ----------
    level : object
        Must expose ``actions``, an iterable of dicts whose ``"1"`` entry
        holds ``"action_id"`` (and ``"action_detail"`` for saves).

    Returns
    -------
    list of nx.DiGraph
        One graph per state-path-save, in action order.
    """
    graphs = []
    for action in level.actions:
        payload = action["1"]
        action_id = payload["action_id"]

        if action_id == "victory":
            # Attribute the victory to the most recent saved path. A victory
            # logged before any save has no graph to mark, so skip it rather
            # than raising IndexError on an empty list.
            if graphs:
                graphs[-1].graph["victory"] = True
            continue

        if action_id != "state-path-save":
            continue

        graph_detail = json.loads(payload["action_detail"])
        graph = nx.DiGraph(victory=False, reset=False)
        labels = []  # labels[i] is the graph node for graph_detail["nodes"][i]
        for node in graph_detail["nodes"]:
            if node.get("data") == "reset":
                # Every reset entry maps onto the single "reset" node.
                labels.append("reset")
                graph.graph["reset"] = True
                graph.add_node("reset")
            else:
                # Nodes are labeled with the sorted string representation
                # of the board state, so identical boards share one node.
                labels.append(repr(sorted(node["data"]["board"])))
                graph.add_node(labels[-1], node_data=node["data"])

        # Edges refer to nodes by their position in graph_detail["nodes"].
        for edge in graph_detail["edges"]:
            graph.add_edge(labels[edge["from"]], labels[edge["to"]])

        graphs.append(graph)
    return graphs


def get_complete_state_graphs(level_sequence, level_id):
    """Collect the state graphs from every playthrough of ``level_id``.

    Walks ``level_sequence`` in order and concatenates the result of
    ``get_state_graphs`` for each entry whose ``id`` equals ``level_id``.
    """
    return [
        state_graph
        for playthrough in level_sequence
        if playthrough.id == level_id
        for state_graph in get_state_graphs(playthrough)
    ]


def only_complete_graphs(graphs):
    """
    Keep only the graphs that ended in a reset or a victory.

    A graph is kept when its ``reset`` or ``victory`` graph attribute is
    truthy (``reset`` is checked first). Input order is preserved.
    """
    complete = []
    for candidate in graphs:
        attrs = candidate.graph
        if attrs["reset"] or attrs["victory"]:
            complete.append(candidate)
    return complete


def draw_graph(graph, size=(20, 20)):
    """
    Plot a state graph on a new matplotlib figure.

    Parameters
    ----------
    graph : nx.Graph
        If ``graph.graph["weighted"]`` is truthy (as ``merge_graphs`` sets),
        every edge needs a ``weight`` attribute and every node needs
        ``weight``, ``terminal`` and ``initial`` attributes; the graph is
        then drawn on a shell layout with edges and nodes styled by weight.
        Otherwise it is drawn with networkx defaults and labels.
    size : tuple of (float, float)
        Figure size in inches.
    """
    # Always open a new figure. Reusing a fixed figure number draws on top of
    # any existing figure with that number, and matplotlib ignores
    # ``figsize`` when the numbered figure already exists.
    plt.figure(figsize=size)
    if graph.graph.get("weighted"):
        _draw_weighted_graph(graph)
    else:
        nx.draw_networkx(
            graph,
            with_labels=True,
        )
    plt.axis('off')
    plt.tight_layout()


def _draw_weighted_graph(graph):
    """Draw ``graph`` on a shell layout, styling edges and nodes by ``weight``."""
    pos = nx.shell_layout(graph)

    # Edges: thick solid above 0.5, dashed in (0.25, 0.5], thin and faint below.
    edges = list(graph.edges(data=True))
    high = [(u, v) for (u, v, d) in edges if d["weight"] > 0.5]
    med = [(u, v) for (u, v, d) in edges if 0.25 < d["weight"] <= 0.5]
    low = [(u, v) for (u, v, d) in edges if d["weight"] <= 0.25]
    nx.draw_networkx_edges(graph, pos, edgelist=high, width=5)
    nx.draw_networkx_edges(graph, pos, edgelist=med, width=3,
                           alpha=0.8, style="dashed")
    nx.draw_networkx_edges(graph, pos, edgelist=low,
                           width=1, alpha=0.5, style="dashed")

    # Nodes: sized by weight, then start/end states recoloured on top.
    nodes = list(graph.nodes(data=True))
    terminal = [n for (n, d) in nodes if d["terminal"]]
    initial = [n for (n, d) in nodes if d["initial"]]
    heavy = [n for (n, d) in nodes if d["weight"] > 0.5]
    light = [n for (n, d) in nodes if d["weight"] <= 0.5]
    nx.draw_networkx_nodes(graph, pos, heavy, node_size=300)
    nx.draw_networkx_nodes(graph, pos, light, node_size=150)
    # Drawn after the weight layers so the colour shows; default node size.
    nx.draw_networkx_nodes(graph, pos, terminal, node_color="blue")
    nx.draw_networkx_nodes(graph, pos, initial, node_color="green")

    nx.draw_networkx_labels(graph, pos, font_size=16, font_family='sans-serif')
    
def merge_graphs(graphs):
    """
    Merge several state graphs into one weighted ``nx.DiGraph``.

    Nodes and edges are matched across the inputs by their labels.

    Node attributes:
        count:    number of input graphs containing the node.
        weight:   count / largest node count, so in (0, 1].
        terminal: True if the node has no outgoing edge in any input graph.
        initial:  True if the node has no incoming edge in any input graph.
    Edge attributes:
        count:  number of input graphs containing the edge.
        weight: count / largest edge count.

    The result has ``graph.graph["weighted"] = True``. An empty input, or
    inputs without edges, produces an empty (or edge-less) graph instead of
    raising ``ValueError`` from ``max()``.
    """
    node_counts = collections.Counter()
    edge_counts = collections.Counter()
    for source in graphs:
        node_counts.update(source.nodes())
        edge_counts.update(source.edges())

    # Compute terminal/initial up front and pass them to add_node. This avoids
    # patching attributes afterwards via ``G.node[...]``, which was removed
    # in networkx 2.4.
    has_successor = {u for (u, _) in edge_counts}
    has_predecessor = {v for (_, v) in edge_counts}

    merged = nx.DiGraph()
    max_node_count = max(node_counts.values(), default=1)
    for node, count in node_counts.items():
        merged.add_node(node, weight=count / max_node_count, count=count,
                        terminal=node not in has_successor,
                        initial=node not in has_predecessor)

    max_edge_count = max(edge_counts.values(), default=1)
    for (u, v), count in edge_counts.items():
        merged.add_edge(u, v, weight=count / max_edge_count, count=count)

    merged.graph["weighted"] = True
    return merged

In [None]:
# Parse the recorded events for "p2" (presumably a player/session id — TODO confirm).
# NOTE(review): on Python < 3.10 `parser` is also a stdlib module name; confirm
# the project's parser module is the one being imported.
(events, level_sequence) = parser.read_events("p2")

In [None]:
# State graphs for the level at position 75, keeping only those flagged
# with a reset or a victory.
graphs = only_complete_graphs(
    get_state_graphs(level_sequence[75])
)

In [None]:
# Plot the first complete graph (IndexError if none were kept).
draw_graph(graphs[0], size=(20, 20))

In [None]:
# Merged, weighted view of every complete playthrough of level id 56.
draw_graph(
    merge_graphs(
        only_complete_graphs(get_complete_state_graphs(level_sequence, level_id=56))
    ),
    size=(25, 25),
)

In [None]:
# Merged, weighted view of every complete playthrough of level id 65.
draw_graph(
    merge_graphs(
        only_complete_graphs(get_complete_state_graphs(level_sequence, level_id=65))
    ),
    size=(25, 25),
)

In [None]:
# Merged, weighted view of every complete playthrough of level id 25.
draw_graph(
    merge_graphs(
        only_complete_graphs(get_complete_state_graphs(level_sequence, level_id=25))
    ),
    size=(25, 25),
)