In [1]:
import json
import pprint

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import scipy.stats
import torchvision

import datasets
import hier

In [2]:
# Read the hierarchy edges from the iNat21 CSV and build the tree plus its node names.
# The edges are consumed while the file is still open.
with open('resources/hierarchy/inat21.csv') as edge_file:
    edges = hier.load_edges(edge_file)
    tree, node_names = hier.make_hierarchy_from_edges(edges)

# Names of the leaf nodes, in the order reported by tree.leaf_subset().
label_order = [node_names[leaf] for leaf in tree.leaf_subset()]

In [3]:
def find_mid_level_nodes(tree, min_size, max_size):
    """Return indices of the topmost nodes whose leaf count lies in [min_size, max_size].

    A node is selected when it has at most `max_size` leaf descendants, at least
    `min_size` leaf descendants, and its parent has more than `max_size` leaf
    descendants. The parent condition ensures that only the highest node on each
    path that fits under `max_size` is considered.

    Args:
        tree: Hierarchy providing num_leaf_descendants() and parents(root_loop=True),
            both indexable per node.
        min_size: Smallest number of leaf descendants a selected node may have.
        max_size: Largest number of leaf descendants a selected node may have.

    Returns:
        1-D integer array of selected node indices in increasing order.

    NOTE(review): parents(root_loop=True) presumably makes the root its own
    parent, in which case the root can never be selected, even when the whole
    tree fits within the bounds - confirm this is intended.
    """
    leaf_counts = tree.num_leaf_descendants()
    parent_of = tree.parents(root_loop=True)

    fits_under_max = leaf_counts <= max_size
    within_bounds = fits_under_max & (leaf_counts >= min_size)
    # A node whose parent also fits under max_size is skipped:
    # the parent (or one of its ancestors) is the candidate instead.
    parent_too_large = np.logical_not(fits_under_max[parent_of])

    return np.flatnonzero(within_bounds & parent_too_large)

In [4]:
def summarize_cut(tree, min_size, max_size):
    """Print a summary of the cut of `tree` at nodes sized within [min_size, max_size].

    The mid-level nodes are chosen by find_mid_level_nodes(). The summary lists
    how many mid-level nodes were found, node counts for the sub-tree that spans
    them, how many leaves of `tree` sit beneath them, and a text rendering of
    that sub-tree produced by hier.format_tree().

    Note: node labels are read from the notebook-level `node_names`, which is
    not a parameter; it must correspond to `tree`.

    Args:
        tree: Hierarchy to cut.
        min_size: Minimum number of leaf descendants of a mid-level node.
        max_size: Maximum number of leaf descendants of a mid-level node.

    Returns:
        None. All results are printed.
    """
    mid_level_subset = find_mid_level_nodes(tree, min_size, max_size)

    # No node satisfies the bounds: np.concatenate() raises ValueError on an
    # empty sequence, so report the empty cut explicitly instead of crashing.
    if len(mid_level_subset) == 0:
        print('mid-level nodes:', 0)
        print('leaf nodes:', 0)
        return

    # Find leaf descendants of each mid-level node.
    is_leaf = tree.leaf_mask()
    is_ancestor = tree.ancestor_mask()
    leaf_descendants = [np.flatnonzero(is_leaf & is_ancestor[u, :]) for u in mid_level_subset]
    leaf_subset = np.sort(np.concatenate(leaf_descendants))

    # Take sub-tree down to mid-level nodes.
    subtree, node_subset = hier.rooted_subtree_spanning(tree, mid_level_subset)

    print('mid-level nodes:', len(mid_level_subset))
    # "Non-trivial" excludes sub-tree nodes that have exactly one child.
    print('non-trivial internal nodes:', subtree.num_nodes() - np.sum(subtree.num_children() == 1))
    print('total internal nodes:', subtree.num_nodes())
    print('leaf nodes:', len(leaf_subset))

    # Print subtree.
    print()
    print(hier.format_tree(subtree, [node_names[i] for i in node_subset], include_size=True))

In [5]:
# A (500, 1000) cut keeps too few classes (< 20%).
# Missing: Mammalia, Reptilia, Insecta
summarize_cut(tree, min_size=500, max_size=1000)

mid-level nodes: 3
non-trivial internal nodes: 5
total internal nodes: 10
leaf nodes: 1890

Life (3)
├── Animalia
│   └── Chordata
│       └── Aves
│           └── Passeriformes
└── Plantae (2)
    └── Tracheophyta (2)
        ├── Liliopsida
        └── Magnoliopsida
            └── Asterales



In [6]:
# A (200, 500) cut is reasonable: it keeps ~35% of classes.
summarize_cut(tree, min_size=200, max_size=500)

mid-level nodes: 13
non-trivial internal nodes: 20
total internal nodes: 24
leaf nodes: 3756

Life (13)
├── Animalia (5)
│   ├── Arthropoda (3)
│   │   └── Insecta (3)
│   │       ├── Coleoptera
│   │       ├── Lepidoptera
│   │       │   └── Nymphalidae
│   │       └── Odonata
│   └── Chordata (2)
│       ├── Mammalia
│       └── Reptilia
├── Fungi
└── Plantae (7)
    └── Tracheophyta (7)
        ├── Liliopsida (2)
        │   ├── Asparagales
        │   └── Poales
        └── Magnoliopsida (5)
            ├── Asterales
            │   └── Asteraceae
            ├── Caryophyllales
            ├── Fabales
            ├── Lamiales
            └── Rosales



In [7]:
# A (100, 200) cut keeps a similar share of classes (~35%).
summarize_cut(tree, min_size=100, max_size=200)

mid-level nodes: 23
non-trivial internal nodes: 32
total internal nodes: 42
leaf nodes: 3427

Life (23)
├── Animalia (12)
│   ├── Arthropoda (8)
│   │   ├── Arachnida
│   │   └── Insecta (7)
│   │       ├── Hemiptera
│   │       ├── Hymenoptera
│   │       ├── Lepidoptera (4)
│   │       │   ├── Erebidae
│   │       │   ├── Geometridae
│   │       │   ├── Lycaenidae
│   │       │   └── Noctuidae
│   │       └── Odonata
│   │           └── Libellulidae
│   ├── Chordata (3)
│   │   ├── Actinopterygii
│   │   ├── Amphibia
│   │   └── Aves
│   │       └── Charadriiformes
│   └── Mollusca
├── Fungi
│   └── Basidiomycota
│       └── Agaricomycetes
│           └── Agaricales
└── Plantae (10)
    └── Tracheophyta (10)
        ├── Liliopsida (2)
        │   ├── Asparagales
        │   │   └── Orchidaceae
        │   └── Poales
        │       └── Poaceae
        ├── Magnoliopsida (7)
        │   ├── Apiales
        │   ├── Ericales
        │   ├── Gentianales
        │   ├── Lamiales
        │ 

In [8]:
# A (100, 500) cut includes many more classes (~65%).
summarize_cut(tree, min_size=100, max_size=500)

mid-level nodes: 30
non-trivial internal nodes: 39
total internal nodes: 42
leaf nodes: 6449

Life (30)
├── Animalia (16)
│   ├── Arthropoda (10)
│   │   ├── Arachnida
│   │   └── Insecta (9)
│   │       ├── Coleoptera
│   │       ├── Hemiptera
│   │       ├── Hymenoptera
│   │       ├── Lepidoptera (5)
│   │       │   ├── Erebidae
│   │       │   ├── Geometridae
│   │       │   ├── Lycaenidae
│   │       │   ├── Noctuidae
│   │       │   └── Nymphalidae
│   │       └── Odonata
│   ├── Chordata (5)
│   │   ├── Actinopterygii
│   │   ├── Amphibia
│   │   ├── Aves
│   │   │   └── Charadriiformes
│   │   ├── Mammalia
│   │   └── Reptilia
│   └── Mollusca
├── Fungi
└── Plantae (13)
    └── Tracheophyta (13)
        ├── Liliopsida (2)
        │   ├── Asparagales
        │   └── Poales
        ├── Magnoliopsida (10)
        │   ├── Apiales
        │   ├── Asterales
        │   │   └── Asteraceae
        │   ├── Caryophyllales
        │   ├── Ericales
        │   ├── Fabales
        │   ├── G

In [9]:
# A (100, 300) cut keeps almost as many classes (~55%) with less severe imbalance.
# Missing: Asterales (includes sunflower).
summarize_cut(tree, min_size=100, max_size=300)

mid-level nodes: 29
non-trivial internal nodes: 38
total internal nodes: 43
leaf nodes: 5577

Life (29)
├── Animalia (16)
│   ├── Arthropoda (10)
│   │   ├── Arachnida
│   │   └── Insecta (9)
│   │       ├── Coleoptera
│   │       ├── Hemiptera
│   │       ├── Hymenoptera
│   │       ├── Lepidoptera (5)
│   │       │   ├── Erebidae
│   │       │   ├── Geometridae
│   │       │   ├── Lycaenidae
│   │       │   ├── Noctuidae
│   │       │   └── Nymphalidae
│   │       └── Odonata
│   ├── Chordata (5)
│   │   ├── Actinopterygii
│   │   ├── Amphibia
│   │   ├── Aves
│   │   │   └── Charadriiformes
│   │   ├── Mammalia
│   │   └── Reptilia
│   │       └── Squamata
│   └── Mollusca
├── Fungi
│   └── Basidiomycota
└── Plantae (12)
    └── Tracheophyta (12)
        ├── Liliopsida (2)
        │   ├── Asparagales
        │   └── Poales
        ├── Magnoliopsida (9)
        │   ├── Apiales
        │   ├── Caryophyllales
        │   ├── Ericales
        │   ├── Fabales
        │   ├── Gentianales


In [10]:
# A (250, 750) cut keeps far fewer classes (~35%).
summarize_cut(tree, min_size=250, max_size=750)

mid-level nodes: 9
non-trivial internal nodes: 14
total internal nodes: 18
leaf nodes: 3713

Life (9)
├── Animalia (3)
│   ├── Arthropoda
│   │   └── Insecta
│   │       └── Odonata
│   └── Chordata (2)
│       ├── Aves
│       │   └── Passeriformes
│       └── Reptilia
├── Fungi
└── Plantae (5)
    └── Tracheophyta (5)
        ├── Liliopsida
        └── Magnoliopsida (4)
            ├── Asterales
            ├── Caryophyllales
            ├── Fabales
            └── Lamiales



In [11]:
# A (200, 1000) cut is somewhere in between: it keeps ~45% of classes.
summarize_cut(tree, min_size=200, max_size=1000)

mid-level nodes: 13
non-trivial internal nodes: 19
total internal nodes: 23
leaf nodes: 4660

Life (13)
├── Animalia (6)
│   ├── Arthropoda (3)
│   │   └── Insecta (3)
│   │       ├── Coleoptera
│   │       ├── Lepidoptera
│   │       │   └── Nymphalidae
│   │       └── Odonata
│   └── Chordata (3)
│       ├── Aves
│       │   └── Passeriformes
│       ├── Mammalia
│       └── Reptilia
├── Fungi
└── Plantae (6)
    └── Tracheophyta (6)
        ├── Liliopsida
        └── Magnoliopsida (5)
            ├── Asterales
            ├── Caryophyllales
            ├── Fabales
            ├── Lamiales
            └── Rosales



In [12]:
def make_two_level_tree(tree, node_names, mid_level_subset):
    """Build a root -> mid-level -> leaf graph keyed by node name.

    Args:
        tree: Hierarchy providing leaf_mask() and ancestor_mask().
        node_names: Maps node index to node name; index 0 is used as the root.
        mid_level_subset: Node indices that form the middle level.

    Returns:
        nx.DiGraph with an edge from the root name to every mid-level name,
        and from every mid-level name to the name of each of its leaf
        descendants in `tree`.
    """
    leaf_mask = tree.leaf_mask()
    ancestor_mask = tree.ancestor_mask()

    graph = nx.DiGraph()
    for mid in mid_level_subset:
        root_name = node_names[0]
        mid_name = node_names[mid]
        # Hang the mid-level node directly beneath the root.
        graph.add_edge(root_name, mid_name)
        # Hang every leaf descendant directly beneath the mid-level node.
        leaves = np.flatnonzero(ancestor_mask[mid, :] & leaf_mask)
        graph.add_edges_from([(mid_name, node_names[leaf]) for leaf in leaves])

    return graph

In [13]:
def flatten_beyond(tree, node_names, mid_level_subset):
    """Build a root -> ... -> mid-level -> leaf graph keyed by node name.

    The structure of `tree` is kept from the root down to the mid-level nodes;
    below each mid-level node, all of its leaf descendants are attached
    directly to it.

    Args:
        tree: Hierarchy providing leaf_mask() and ancestor_mask().
        node_names: Maps node index in `tree` to node name.
        mid_level_subset: Node indices beyond which the hierarchy is flattened.

    Returns:
        nx.DiGraph containing the edges of the sub-tree returned by
        hier.rooted_subtree_spanning(), followed by an edge from every
        mid-level name to the name of each of its leaf descendants.
    """
    leaf_mask = tree.leaf_mask()
    ancestor_mask = tree.ancestor_mask()

    graph = nx.DiGraph()

    # Upper part: the sub-tree reaching down to the mid-level nodes.
    # node_subset translates sub-tree indices back to indices in `tree`.
    subtree, node_subset = hier.rooted_subtree_spanning(tree, mid_level_subset)
    graph.add_edges_from([
        (node_names[node_subset[src]], node_names[node_subset[dst]])
        for src, dst in subtree.edges()
    ])

    # Lower part: every leaf descendant hangs directly beneath its mid-level node.
    for mid in mid_level_subset:
        leaves = np.flatnonzero(ancestor_mask[mid, :] & leaf_mask)
        graph.add_edges_from([(node_names[mid], node_names[leaf]) for leaf in leaves])

    return graph

In [14]:
# import prepare_hierarchy.util

In [15]:
# mid_level_nodes = find_mid_level_nodes(tree, 100, 200)
# subg = make_two_level_tree(tree, node_names, mid_level_nodes)
# label_subset = [x for x in label_order if x in subg]
# subg_edges = prepare_hierarchy.util.dfs_edges_with_order(subg, label_subset)

In [16]:
# mid_level_nodes = find_mid_level_nodes(tree, 100, 200)
# subg = flatten_beyond(tree, node_names, mid_level_nodes)
# label_subset = [x for x in label_order if x in subg]
# subg_edges = prepare_hierarchy.util.dfs_edges_with_order(subg, label_subset)