# Visualize Taxonomy

We use the [*networkx*](https://networkx.github.io/) Python package for analyzing graphs and the [*graphviz*](https://graphviz.readthedocs.io/) Python package for drawing graphs.

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
from typing import Any, Mapping, Tuple

import graphviz as gv
import networkx as nx
import pandas as pd

from classification.json_validator import TaxonNode, build_taxonomy_dict

In [None]:
# Placeholder path: point this at your local copy of camera_trap_taxonomy_mapping.csv before running.
TAXONOMY_CSV_PATH = '/path/to/camera-traps-private/camera_trap_taxonomy_mapping.csv'

## Load taxonomy CSV

In [None]:
# Load the taxonomy mapping CSV and preview its first rows.
# display() is provided by the IPython/Jupyter runtime (it is not imported in this notebook).
df = pd.read_csv(TAXONOMY_CSV_PATH)
display(df.head())

In [None]:
# display all rows without a taxonomy string
# with pd.option_context('display.max_rows', 100):
#     display(df[df['taxonomy_string'].isna()])

In [None]:
# Build the two taxonomy lookups from the CSV:
#   taxon_to_node: keyed by (taxon_level, taxon_name) tuples; the graph helpers below consume it.
#   label_to_node: keyed by label pairs such as ('idfg', 'lion') -- presumably (dataset, dataset_label); confirm.
taxon_to_node, label_to_node = build_taxonomy_dict(df)

## Graph helper methods

In [None]:
def build_nx_graph(taxon_to_node: Mapping[Tuple[str, str], TaxonNode]) -> nx.DiGraph:
    """Build a networkx DiGraph of the taxonomy with parent -> child edges.

    Args:
        taxon_to_node: maps each (taxon_level, taxon_name) key to its TaxonNode.

    Returns:
        A DiGraph whose nodes are (taxon_level, taxon_name) tuples. An
        AssertionError is raised if the resulting graph contains a cycle.
    """
    graph = nx.DiGraph()
    for parent_key, parent in taxon_to_node.items():
        # Add the key explicitly so taxa without children still appear as nodes.
        graph.add_node(parent_key)
        graph.add_edges_from(
            (parent_key, (kid.level, kid.name)) for kid in parent.children
        )
    assert nx.is_directed_acyclic_graph(graph)
    return graph


def nx_to_gv(nx_g: nx.DiGraph) -> gv.Digraph:
    """Convert a networkx graph into a graphviz Digraph.

    Each networkx node must be a tuple of strings, e.g. (taxon_level, taxon_name).
    Its graphviz name is those strings joined with newlines, so each part is
    drawn on its own line.
    """
    def as_gv_name(node_tuple) -> str:
        return '\n'.join(node_tuple)

    out = gv.Digraph()
    for nx_node in nx_g.nodes:
        out.node(as_gv_name(nx_node))
    for tail, head in nx_g.edges:
        out.edge(as_gv_name(tail), as_gv_name(head))
    return out


def visualize_subgraph(nx_graph: nx.DiGraph, node: Any) -> gv.Digraph:
    """Return a graphviz Digraph of `node` together with all of its descendants.

    Args:
        nx_graph: the full networkx graph.
        node: root of the subgraph to draw; must be a node of nx_graph.
    """
    reachable = nx.descendants(nx_graph, node) | {node}
    return nx_to_gv(nx_graph.subgraph(reachable))


def build_graphviz_graph(taxon_to_node: Mapping[Tuple[str, str], TaxonNode]) -> gv.Digraph:
    """Draw the whole taxonomy directly as a graphviz Digraph (parent -> child edges).

    Each node is named "<level>\\n<name>", so the level and name render on
    separate lines. No cycle check is performed here.
    """
    def gv_name(level, name) -> str:
        return f'{level}\n{name}'

    dot = gv.Digraph()
    dot.attr(overlap='false')
    for (level, name), taxon in taxon_to_node.items():
        parent_name = gv_name(level, name)
        dot.node(parent_name)
        for kid in taxon.children:
            dot.edge(parent_name, gv_name(kid.level, kid.name))
    return dot

## Render the entire taxonomy graph to SVG

The rendered graph may include several disconnected components.

In [None]:
# Build the graphviz drawing of the entire taxonomy (every component) directly from taxon_to_node.
gv_g = build_graphviz_graph(taxon_to_node)
# Render to SVG under the base name 'taxonomy'; cleanup=True asks graphviz to delete the
# intermediate DOT source file afterwards.
gv_g.render('taxonomy', format='svg', cleanup=True)
# display(gv_g)  # this can be huge

## Analyze Taxonomy graph

In [None]:
# Build the networkx graph and summarize each weakly connected component by its root taxa.
nx_g = build_nx_graph(taxon_to_node)
print('Number of disconnected components:', nx.number_weakly_connected_components(nx_g))

for component_nodes in nx.weakly_connected_components(nx_g):
    # component_nodes is a set of (taxon_level, taxon_name) tuples
    component = nx_g.subgraph(component_nodes)

    # Roots are nodes with no parent. build_nx_graph only checks that the graph is
    # acyclic, not that it is a tree. A single component can therefore contain
    # several roots, e.g. if a taxon has parents under different roots. Report all
    # of them instead of asserting there is exactly one.
    # Every root is a key of taxon_to_node, because nodes that only appear as
    # children always have an incoming edge.
    root_taxa = [n for n, in_deg in component.in_degree() if in_deg == 0]
    if len(root_taxa) != 1:
        print(f'WARNING: component with {len(component)} nodes has {len(root_taxa)} roots')

    for root_taxon in root_taxa:
        print(f'Graph rooted at {root_taxon} has {len(component)} nodes')
        print('   ', taxon_to_node[root_taxon])

### Display a disconnected component

In [None]:
# Draw the node ('phylum', 'tracheophyta') and all of its descendants.
# The node must be present in nx_g. This cell rebinds gv_g, replacing the full-taxonomy graph built earlier.
gv_g = visualize_subgraph(nx_g, node=('phylum', 'tracheophyta'))
display(gv_g)

## Query the graph

### Show the smallest subgraph containing all given dataset labels

In [None]:
# Label pairs to locate in the taxonomy (keys of label_to_node).
labels = [
    ('idfg_swwlf_2019', 'mountain_lion'),
    ('idfg_swwlf_2019', 'bobcat'),
    ('idfg_swwlf_2019', 'cat_domestic'),
    ('idfg_swwlf_2019', 'lynx'),
    ('idfg', 'lion'),
]
nodes = [label_to_node[ds_label] for ds_label in labels]

# The smallest subgraph containing every requested label is the one rooted at their
# lowest common ancestor.
lca_node = TaxonNode.lowest_common_ancestor(set(nodes))
display(lca_node)

lca_key = (lca_node.level, lca_node.name)
display(visualize_subgraph(nx_g, node=lca_key))

### Get the set of dataset labels corresponding to this subgraph

In [None]:
# Dataset labels for the subgraph rooted at lca_node. As the cell's last expression,
# Jupyter displays the returned value.
lca_node.get_dataset_labels()