In [None]:
import networkx as nx
import pandas as pd


def _compute_graph_layout(graph):
    """Return Kamada-Kawai node positions for ``graph``.

    The target distance between two nodes is their (unweighted) shortest-path
    length. Pairs with no connecting path, such as nodes in different
    connected components, get the largest finite path length found in the
    graph (at least 1), so every pair has a bounded target distance.

    Args:
        graph: a networkx graph.

    Returns:
        dict mapping each node to its position, as produced by
        ``nx.kamada_kawai_layout``; ``{}`` for a graph without nodes.
    """
    nodes = list(graph.nodes())
    if not nodes:
        return {}

    # source -> {target: path length}; only reachable targets are present.
    reachable = dict(nx.shortest_path_length(graph))

    # Largest finite distance. Floor it at 1 so an edgeless graph does not
    # end up with an all-zero target-distance matrix.
    longest = max(
        (d for targets in reachable.values() for d in targets.values()),
        default=0,
    )
    fill = max(longest, 1)

    distances = {
        source: {
            target: reachable.get(source, {}).get(target, fill)
            for target in nodes
        }
        for source in nodes
    }
    return nx.kamada_kawai_layout(graph, dist=distances)


class Map:
    """Term map backed by an undirected ``networkx.Graph``.

    Nodes and edges are added through thin wrappers around the matching
    ``networkx.Graph`` methods. :meth:`plot` draws the graph using the
    module-level ``_compute_graph_layout`` helper.
    """

    def __init__(self):
        self._graph = nx.Graph()
        # Node -> label mapping; currently not read by any method in this class.
        self._node_labels = {}

    def _compute_graph_layout(self):
        """Return node positions for this map's graph.

        Inside a method body, the bare name ``_compute_graph_layout``
        resolves to the module-level function, not to this method, so this
        simply delegates to that helper.
        """
        return _compute_graph_layout(self._graph)

    def add_node(self, node_for_adding, **attr):
        """Add one node (with optional attributes) to the map.

        Forwards its arguments unchanged to ``networkx.Graph.add_node``.
        """
        self._graph.add_node(node_for_adding, **attr)

    def add_nodes_from(self, nodes_for_adding, **attr):
        """Add every node in ``nodes_for_adding`` to the map.

        Forwards its arguments unchanged to ``networkx.Graph.add_nodes_from``.
        """
        self._graph.add_nodes_from(nodes_for_adding, **attr)

    def add_edge(self, u_of_edge, v_of_edge, **attr):
        """Add an edge between ``u_of_edge`` and ``v_of_edge``.

        Forwards its arguments unchanged to ``networkx.Graph.add_edge``.
        """
        self._graph.add_edge(u_of_edge, v_of_edge, **attr)

    def add_edges_from(self, ebunch_to_add, **attr):
        """Add every edge in ``ebunch_to_add`` to the map.

        Forwards its arguments unchanged to ``networkx.Graph.add_edges_from``.
        """
        self._graph.add_edges_from(ebunch_to_add, **attr)

    def plot(self):
        """Draw the map on the current matplotlib figure.

        Clears the current figure, draws every node in red and every edge
        with width 1, then hides the axes.
        """
        plt.clf()

        layout = self._compute_graph_layout()

        nx.draw_networkx_nodes(self._graph, layout, node_color="red")
        # TODO: draw node labels (e.g. from ``self._node_labels``).
        nx.draw_networkx_edges(self._graph, layout, width=1)
        plt.axis("off")

    @staticmethod
    def _rescale_sizes(values, size):
        """Linearly map ``values`` onto the range ``size[0]`` .. ``size[1]``.

        The smallest value maps to ``size[0]`` and the largest to ``size[1]``.
        If all values are equal, every value maps to ``size[0]``, which
        avoids a division by zero. Returns ``[]`` for an empty input.
        """
        if not values:
            return []
        low, high = min(values), max(values)
        span = high - low
        if span == 0:
            return [size[0]] * len(values)
        return [size[0] + (v - low) / span * (size[1] - size[0]) for v in values]

    def ocurrence_map(
        self,
        min_value=None,
        top_links=None,
        figsize=(10, 10),
        font_size=12,
        factor=None,
        size=(300, 1000),
    ):
        """Cluster map for occurrence and co-occurrence matrices.

        Draws one red node per term and one light-blue "quantity" node per
        link. Each quantity node is connected to the two terms it relates and
        is labelled with the link's value.

        NOTE(review): this method reads ``self._call``, ``self.tomatrix()``
        and ``self.iterrows()``, none of which ``Map`` defines. It appears to
        expect a DataFrame-like result object. It also relies on ``plt`` and
        ``cut_text`` being defined elsewhere.

        Args:
            min_value: accepted but not used.
            top_links: accepted but not used.
            figsize: size of the matplotlib figure that is created.
            font_size: font size of the term labels.
            factor: accepted but not used.
            size: ``(smallest, largest)`` node size; term and quantity node
                sizes are each scaled linearly into this range.

        Raises:
            ValueError: if ``self._call`` is neither ``"ocurrence"`` nor
                ``"co_ocurrence"``.

        >>> import pandas as pd
        >>> import matplotlib.pyplot as plt
        >>> from techminer.datasets import load_test_cleaned
        >>> rdf = load_test_cleaned().data
        >>> rdf.co_ocurrence(
        ...    column_r='Authors',
        ...    column_c='Authors',
        ...    sep_r=',',
        ...    sep_c=',',
        ...    top_n=10
        ... ).heatmap()
        >>> plt.savefig('./figs/heatmap-ocurrence-map.jpg')

        .. image:: ../figs/heatmap-ocurrence-map.jpg
            :width: 600px
            :align: center

        >>> rdf.co_ocurrence(
        ...    column_r='Authors',
        ...    column_c='Authors',
        ...    sep_r=',',
        ...    sep_c=',',
        ...    top_n=10
        ... ).ocurrence_map(
        ...    figsize=(11,11),
        ...    font_size=10,
        ...    factor = 0.1,
        ...    size=(300,1000)
        ... )
        >>> plt.savefig('./figs/ocurrence-map.jpg')

        .. image:: ../figs/ocurrence-map.jpg
            :width: 600px
            :align: center
        """
        if self._call not in ["ocurrence", "co_ocurrence"]:
            raise ValueError("Invalid call for result of function:" + self._call)

        plt.figure(figsize=figsize)

        graph = nx.Graph()

        # The matrix is needed for both the term lists and the square case;
        # build it once.
        mtx = self.tomatrix()
        terms_r = list(set(mtx.index.tolist()))
        terms_c = list(set(mtx.columns.tolist()))

        # Term nodes are named by their cut_text form. Edge endpoints below
        # go through cut_text as well, so they attach to these same nodes
        # instead of creating unnamed duplicates.
        nodes = [cut_text(term) for term in set(terms_r + terms_c)]
        graph.add_nodes_from(nodes)

        # quantity-node id -> value shown on that node
        labels = {}

        if sorted(terms_r) != sorted(terms_c):
            # Row and column terms differ: one quantity node per row of self.
            # The node links the terms in fields 0 and 1; field 2 is its value.
            for idx, row in self.iterrows():
                key = str(idx)
                graph.add_node(key)
                graph.add_edge(cut_text(row[0]), key)
                graph.add_edge(cut_text(row[1]), key)
                labels[key] = row[2]
        else:
            # Square matrix: one quantity node per positive cell on or above
            # the diagonal, so each symmetric pair is drawn once.
            n = 0
            for idx_r, row in enumerate(mtx.index.tolist()):
                for idx_c, col in enumerate(mtx.columns.tolist()):
                    if idx_c < idx_r:
                        continue
                    value = mtx.at[row, col]
                    if value > 0:
                        key = str(n)
                        graph.add_node(key)
                        graph.add_edge(cut_text(row), key)
                        graph.add_edge(cut_text(col), key)
                        labels[key] = value
                        n += 1

        # Quantity-node ids in the same order as their label values.
        numnodes = list(labels)

        layout = _compute_graph_layout(graph)

        ## draw term nodes
        # Assumes each term ends with "[<int>]"; that integer sets the node
        # size. NOTE(review): confirm the cut_text output keeps this suffix.
        term_sizes = self._rescale_sizes(
            [int(name[name.find("[") + 1 : -1]) for name in nodes], size
        )
        nx.draw_networkx_nodes(
            graph, layout, nodelist=nodes, node_size=term_sizes, node_color="red"
        )

        # Place each term's text slightly up and right of its node, with the
        # offset set to 1% of the current axis span.
        x_left, x_right = plt.xlim()
        y_left, y_right = plt.ylim()
        delta_x = (x_right - x_left) * 0.01
        delta_y = (y_right - y_left) * 0.01
        for node in nodes:
            x_pos, y_pos = layout[node]
            plt.text(
                x_pos + delta_x,
                y_pos + delta_y,
                node,
                size=font_size,
                ha="left",
                va="bottom",
                bbox=dict(boxstyle="square", ec="gray", fc="white",),
            )

        ## draw quantity nodes
        quantity_sizes = self._rescale_sizes(
            [int(value) for value in labels.values()], size
        )
        nx.draw_networkx_nodes(
            graph,
            layout,
            nodelist=numnodes,
            node_size=quantity_sizes,
            node_color="lightblue",
        )

        nx.draw_networkx_labels(graph, layout, labels=labels, font_color="black")

        ## edges
        nx.draw_networkx_edges(graph, layout, width=1)
        plt.axis("off")