# MegaClassifier Categories

This notebook determines the classes for MegaClassifier to train on. At a high level, this notebook does the following:

1. Query MegaDB to figure out how many labeled images there are for each dataset-specific class.
2. Build taxonomy tree from the taxonomy CSV, which maps between dataset-specific classes and the taxonomy hierarchy.
3. Using MegaDB query results, determine the number of images at each level in the taxonomy hierarchy.
4. Find the bottom-most "leaves" in the taxonomy hierarchy that contain at least the required minimum number of images. These leaves are the MegaClassifier categories.
5. Save these categories to `label_spec.json` to be used in the classification pipeline.
6. (Optional) Graphically plot the categories within the taxonomy hierarchy.

## Imports and constants

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import json
import os
from pprint import pprint
from typing import Any, Dict, Iterable, Mapping, Optional, Sequence, Tuple

import graphviz as gv
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from tqdm.auto import tqdm

from taxonomy_mapping.taxonomy_graph import TaxonNode, build_taxonomy_graph, dag_to_tree
from data_management.megadb import megadb_utils

In [None]:
# --- Settings that usually stay the same across MegaClassifier versions ---
TAXONOMY_CSV_PATH = '../../camera-traps-private/camera_trap_taxonomy_mapping.csv'
IMAGES_PER_LEAF_THRESHOLD = 2000

# --- Settings adjusted for each MegaClassifier iteration ---
classifier_name = 'megaclassifier-v0.2'
megaclassifier_base = os.path.join('/datadrive/classifier-training/', classifier_name)

# Output locations; all of them live under megaclassifier_base.
num_images_csv_path = os.path.join(megaclassifier_base, 'num_images_per_dataset.csv')
labeled_image_counts_dir = os.path.join(
    megaclassifier_base, f'{classifier_name}-labeled_image_counts')
label_spec_output_file = os.path.join(
    megaclassifier_base, f'{classifier_name}-label_spec.json')

# Make sure the output folders exist before any cell writes into them.
for output_dir in (megaclassifier_base, labeled_image_counts_dir):
    os.makedirs(output_dir, exist_ok=True)

## Load taxonomy CSV and MegaDB

In [None]:
# MegaDB client and its datasets table (its keys are the dataset names used below).
megadb = megadb_utils.MegadbUtils()
ds_table = megadb.get_datasets_table()

# Taxonomy mapping CSV: dataset-specific classes => taxonomy hierarchy.
df = pd.read_csv(TAXONOMY_CSV_PATH)

## Count labeled images

In [None]:
def get_num_images_per_dataset(datasets: Iterable[str],
                               check_existing: Optional[str] = None
                              ) -> pd.Series:
    """Count the total number of images in each dataset, based on MegaDB.

    If the throughput on the sequences container is set to 10,000 RU/s,
    this query should be fairly fast. No more than ~20 minutes for all
    ~40 datasets in MegaDB.

    Queries are run through the module-level `megadb` client.

    Args:
        datasets: list of str, names of datasets for which to check the
            number of images
        check_existing: optional str, path to CSV of existing counts
            with exactly two columns ['dataset', 'num_images']; datasets
            already listed there are not queried again

    Returns: pd.Series, indexed by dataset name, value is count
    """
    if check_existing is None:
        counter = pd.Series(index=pd.Index([], name='dataset'),
                            name='num_images', dtype=int)
    else:
        # `read_csv(squeeze=True)` was removed in pandas 2.0. Squeezing only the
        # columns axis yields a Series even if the CSV holds a single row.
        counter = pd.read_csv(check_existing, index_col='dataset').squeeze('columns')

    query = 'SELECT VALUE SUM(ARRAY_LENGTH(seq.images)) FROM seq'
    # Equivalent query: 'SELECT VALUE COUNT(1) FROM seq JOIN img IN seq.images'

    for ds in tqdm(sorted(datasets)):
        if ds in counter.index:
            tqdm.write(f'{ds} already in existing CSV. Skipping.')
            continue

        # Sometimes the query will return multiple results (perhaps due to
        # CosmosDB paging?), so we just sum the counts.
        results = list(megadb.query_sequences_table(query=query, partition_key=ds))
        if len(results) > 1:
            tqdm.write(f'Got more than one result for {ds}')
        counter[ds] = sum(results)
        tqdm.write(f'Dataset: {ds}, Count: {counter[ds]}')

    counter.sort_index(inplace=True)
    return counter

# No cache is passed, so every dataset is queried from scratch; the result
# is saved as CSV for count_labeled_images() below.
num_images_per_dataset = get_num_images_per_dataset(ds_table.keys())
num_images_per_dataset.to_csv(num_images_csv_path, index=True)

In [None]:
def count_labeled_images(datasets: Iterable[str],
                         save_dir: str,
                         num_images_csv_path: str
                        ) -> None:
    """Count the number of labeled images of each class in MegaDB.

    An image is counted if the following criteria are all met:
    1) Either the sequence is labeled (seq.class) or the image is labeled (img.class).
    2) If the image is labeled, it has exactly 1 class.
    3) Otherwise, if the sequence is labeled, the sequence has exactly 1 class.
    When both the sequence and the image carry a class, the image class is used.

    If the throughput on the sequences container is set to 10,000 RU/s,
    this query should take ~1 hour for all ~40 datasets in MegaDB.

    Results are cached on disk: a dataset whose CSV already exists in
    save_dir is skipped. Queries are run through the module-level `megadb`
    client.

    Args:
        datasets: iterable of str, names of datasets for which to count the
            number of labeled images; it is materialized once, so a one-shot
            iterator is fine
        save_dir: str, path to folder to save one output CSV per dataset,
            each CSV has exactly two columns ['class', 'count']
        num_images_csv_path: str, path to CSV with the number of total images
            per dataset, with exactly two columns ['dataset', 'num_images']
            (the layout produced by saving get_num_images_per_dataset()'s
            result with to_csv)

    Raises:
        AssertionError: if a dataset is missing from num_images_csv_path, if a
            query row has more than one class, or if a dataset's labeled-image
            count exceeds its total image count
    """
    # Materialize (and sort) once: the datasets are iterated twice below.
    datasets = sorted(datasets)

    if not os.path.exists(save_dir):
        os.makedirs(save_dir, exist_ok=True)
        print('Created query cache dir at:', save_dir)

    # `read_csv(squeeze=True)` was removed in pandas 2.0; squeeze columns instead.
    num_images_per_dataset = pd.read_csv(
        num_images_csv_path, index_col='dataset').squeeze('columns')
    for ds in datasets:
        assert ds in num_images_per_dataset.index, \
            f'{ds} missing from {num_images_csv_path}'

    query = '''
    SELECT
        VALUE [[seq.class, img.class], COUNT(1)]
    FROM sequences seq JOIN img IN seq.images
    WHERE
        (ARRAY_LENGTH(img.class) = 1)
        OR
        (NOT ISDEFINED(img.class) AND ARRAY_LENGTH(seq.class) = 1)
    GROUP BY [seq.class, img.class]
    '''

    for ds in tqdm(datasets):
        save_path = os.path.join(save_dir, f'{ds}.csv')
        if os.path.exists(save_path):
            tqdm.write(f'Saved class counts for {ds} already exist. Skipping.')
            continue

        tqdm.write(f'Querying {ds}')
        results = list(megadb.query_sequences_table(query=query, partition_key=ds))

        # Accumulate in a plain dict and build the Series once at the end,
        # instead of growing a Series one class at a time.
        class_counts: Dict[str, int] = {}
        for combined_classes, count in results:
            if len(combined_classes) > 1:
                tqdm.write('- Has both seq and img class. Using img class.')
                tqdm.write(f'- {combined_classes}')
                img_classes = combined_classes[1]
            else:
                img_classes = combined_classes[0]
            assert len(img_classes) == 1
            img_class = img_classes[0]
            class_counts[img_class] = class_counts.get(img_class, 0) + count

        counter = pd.Series(class_counts, name='count', dtype=int)
        counter.index.name = 'class'

        num_labeled_images = counter.sum()
        num_images = num_images_per_dataset[ds]
        assert num_labeled_images <= num_images
        tqdm.write(f'- num labeled images: {num_labeled_images}, num_images: {num_images}')

        counter.sort_index(inplace=True)  # sort counter by class
        counter.to_csv(save_path, index=True)

# Query per-class labeled-image counts for every dataset (one CSV per dataset).
count_labeled_images(
    datasets=ds_table.keys(),
    save_dir=labeled_image_counts_dir,
    num_images_csv_path=num_images_csv_path,
)

In [None]:
# Merge the per-dataset CSVs written by count_labeled_images() into one dict:
# (dataset, dataset_class) => count
ds_label_counts: Dict[Tuple[str, str], int] = {}
for ds in sorted(ds_table.keys()):
    # Read from labeled_image_counts_dir, the folder count_labeled_images()
    # saved into, rather than a path relative to the working directory.
    counts_csv_path = os.path.join(labeled_image_counts_dir, f'{ds}.csv')
    # Keep class names as literal strings so they match the label_to_node keys.
    # With pandas' defaults, names like 'NA' or 'None' would be parsed as NaN,
    # and purely numeric names as numbers. `read_csv(squeeze=True)` was removed
    # in pandas 2.0, so squeeze the single data column explicitly.
    sr = pd.read_csv(counts_csv_path, index_col='class', dtype={'class': str},
                     keep_default_na=False).squeeze('columns')
    sr.index = sr.index.map(lambda x: (ds, x))
    ds_label_counts.update(sr.to_dict())

## Build taxonomy tree

1. Build taxonomy DAG.
2. Remove humans. Leave human detection to MegaDetector, instead of MegaClassifier.
3. Convert DAG to tree. For the purposes of MegaClassifier, we require that there be exactly 1 true taxonomy hierarchy.

In [None]:
# Build the taxonomy DAG from the taxonomy CSV. Judging by how they are used
# below, taxon_to_node is keyed by (level, name), e.g. ('species', 'homo sapiens'),
# and label_to_node by (dataset, dataset_class) -- TODO confirm against
# taxonomy_graph.build_taxonomy_graph.
graph, taxon_to_node, label_to_node = build_taxonomy_graph(df)

In [None]:
# Drop Homo sapiens together with everything below it; human detection is
# left to MegaDetector rather than MegaClassifier.
human_node = taxon_to_node[('species', 'homo sapiens')]
nodes_to_remove = nx.descendants(graph, human_node) | {human_node}
graph.remove_nodes_from(nodes_to_remove)

In [None]:
tree = dag_to_tree(graph, taxon_to_node)
num_connected_components = nx.number_weakly_connected_components(tree)

# Collect exactly one root per weakly connected component. nx.is_tree() only
# checks the underlying undirected graph, so a component with several
# in-degree-0 nodes would pass it, and taking the first such node would
# silently ignore the others. Requiring an arborescence (every node has
# in-degree <= 1) guarantees a unique root.
root_nodes = []
for component in nx.weakly_connected_components(tree):
    # Each component is a set of nodes
    subgraph = tree.subgraph(component)
    assert nx.is_arborescence(subgraph)
    component_roots = [n for n, d in subgraph.in_degree() if d == 0]
    assert len(component_roots) == 1
    root_nodes.append(component_roots[0])
assert len(root_nodes) == num_connected_components

print('Total number of nodes:', len(tree.nodes))
print('Number of disconnected components:', num_connected_components)
pprint(root_nodes)

## Populate image counts

In [None]:
def get_total_img_count(node: TaxonNode) -> int:
    """Return the number of images in the subtree rooted at `node`.

    The subtree total is `node.img_count` plus the totals of all of
    `node.children`, computed recursively. Each result is memoized on the
    node as `node.total_img_count`; once that attribute exists it is returned
    as-is, without recomputation.
    """
    try:
        return node.total_img_count
    except AttributeError:
        pass  # not computed yet

    # For a childless node this sum is 0, so its total is just its img_count.
    descendants_count = sum(get_total_img_count(child) for child in node.children)
    node.total_img_count = node.img_count + descendants_count
    return node.total_img_count


def populate_image_counts(graph: nx.DiGraph,
                          ds_label_counts: Mapping[Tuple[str, str], int],
                          label_to_node: Mapping[Tuple[str, str], TaxonNode]) -> None:
    """Adds 2 properties to the nodes of graph:
    - img_count: int, number of images exactly for this taxon node
      (set on every node in graph)
    - total_img_count: int, number of images in the subtree rooted at this node
      (computed by get_total_img_count() starting from each root)

    Any total_img_count left over from a previous call is cleared first.
    Re-running this (e.g. re-executing the notebook cell after reloading the
    counts) therefore recomputes the totals instead of returning stale
    memoized values.

    Args:
        graph: forest of TaxonNode; each weakly connected component must have
            exactly one root (in-degree 0 node)
        ds_label_counts: (dataset, dataset_class) => number of labeled images
        label_to_node: (dataset, dataset_class) => TaxonNode

    Labels missing from label_to_node, or mapped to nodes that are not in
    graph, are reported and skipped.
    """
    for node in graph.nodes:
        node.img_count = 0
        # get_total_img_count() memoizes on the node; drop totals from a
        # previous run so they are recomputed from the new img_count values.
        if hasattr(node, 'total_img_count'):
            del node.total_img_count

    labels_not_found = []
    for label, count in ds_label_counts.items():
        if label not in label_to_node:
            labels_not_found.append(label)
            continue
        node = label_to_node[label]
        if node not in graph.nodes:
            # presumably a taxon removed from the graph earlier (e.g. humans)
            # or dropped when converting the DAG to a tree -- TODO confirm
            print(f'Node {node} not in graph.nodes')
            continue
        node.img_count += count

    print('labels not found:')
    pprint(labels_not_found)

    for subgraph in nx.weakly_connected_components(graph):
        # Each subgraph is a set of (taxon_level, taxon_name)

        # Get root node
        subgraph = graph.subgraph(subgraph)
        root_nodes = [n for n, d in subgraph.in_degree() if d==0]
        assert len(root_nodes) == 1
        root_node = root_nodes[0]

        print(f'Graph rooted at {root_node} has {len(subgraph.nodes)} nodes')
        print('    Total image count:', get_total_img_count(root_node))

In [None]:
populate_image_counts(graph=tree, ds_label_counts=ds_label_counts, label_to_node=label_to_node)

## Find leaf nodes containing at least a certain threshold of images

Top-down tree search. We define a "leaf" as a tuple of nodes such that:

1. All nodes in the leaf have the same parent.
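2. The sum of the `total_img_count` properties on the nodes is at least the threshold.
3. None of the children of the nodes has a `total_img_count` that reaches the threshold.

A leaf is either a single node, or the group of all below-threshold children of a node that also has at least one child at or above the threshold; the latter is named "<parent> (other)". Images in below-threshold children whose combined count still falls short of the threshold are not assigned to any leaf.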

In [None]:
def get_leaf_nodes(root_nodes: Iterable[TaxonNode], threshold: int
                  ) -> Dict[Tuple[TaxonNode, ...], int]:
    """Find the MegaClassifier "leaves" with a depth-first search from the roots.

    A leaf is either
    - a single node with total_img_count >= threshold, none of whose children
      reaches the threshold on its own, or
    - the tuple of all below-threshold children of a node that has at least
      one child at or above the threshold, provided those children together
      reach the threshold (an "other" leaf).

    Subtrees whose total_img_count is below the threshold are skipped. When a
    node's children are split up, the node itself is not part of any leaf.

    Args:
        root_nodes: roots of the trees to search; every reachable node must
            already have a `total_img_count` attribute
        threshold: minimum number of images per leaf

    Returns: dict, maps each leaf (tuple of TaxonNode) to its total image count
    """
    leaves_to_count: Dict[Tuple[TaxonNode, ...], int] = {}
    assigned = set()  # nodes already placed in some leaf (sanity check only)

    stack = list(root_nodes)
    while stack:
        node = stack.pop()
        if node.total_img_count < threshold:
            continue

        # Partition children by whether they reach the threshold on their own.
        big_children, small_children = [], []
        for child in node.children:
            if child.total_img_count >= threshold:
                big_children.append(child)
            else:
                small_children.append(child)

        if not big_children:
            # No child qualifies alone, so the whole subtree becomes one leaf.
            assert node not in assigned
            leaves_to_count[(node,)] = node.total_img_count
            assigned.add(node)
            continue

        # Every qualifying child is explored further ...
        stack.extend(big_children)

        # ... while the rest are pooled into an "other" leaf if large enough.
        pooled_count = sum(c.total_img_count for c in small_children)
        if pooled_count >= threshold:
            assert all(c not in assigned for c in small_children)
            leaves_to_count[tuple(small_children)] = pooled_count
            assigned.update(small_children)

    return leaves_to_count


def leaf_to_name(leaf: Sequence[TaxonNode]) -> str:
    """Return a human-readable class name for a leaf.

    A single-node leaf is named '<level>: <name>'. A multi-node leaf is named
    after the first node's first parent, as '<level>: <name> (other)', and
    that parent must be the lowest common ancestor of all nodes in the leaf.
    """
    if len(leaf) != 1:
        parent = leaf[0].parents[0]
        assert TaxonNode.lowest_common_ancestor(leaf) is parent
        return f'{parent.level}: {parent.name} (other)'

    (only_node,) = leaf
    return f'{only_node.level}: {only_node.name}'

In [None]:
# Search every tree, starting from its root, for leaves with enough images.
leaves_to_count = get_leaf_nodes(root_nodes=root_nodes,
                                 threshold=IMAGES_PER_LEAF_THRESHOLD)
num_leaves = len(leaves_to_count)
print('Total number of leaves:', num_leaves)

In [None]:
# Histogram of leaf sizes: 1000-image bins up to 40k, plus one wide overflow
# bin (40k to 1e8) that collects all larger leaves.
hist_bins = list(range(0, 40001, 1000)) + [1e8]
plt.hist(list(leaves_to_count.values()), bins=hist_bins)
plt.xlim(0, 41000)
plt.xlabel('num images in leaf')
plt.ylabel('num leaves')
plt.show()

In [None]:
# Show each leaf's name and image count, largest leaves first.
leaves_by_count = sorted(leaves_to_count.items(), key=lambda item: item[1], reverse=True)
display({leaf_to_name(leaf): count for leaf, count in leaves_by_count})

## Write out to label spec JSON

In [None]:
def leaves_to_label_spec(leaves: Iterable[Sequence[TaxonNode]],
                         max_count: Optional[int] = None) -> Dict[str, Any]:
    """Build a classification label specification from a collection of leaves.

    Args:
        leaves: each leaf is a sequence of TaxonNode that together form one class
        max_count: optional int; if given, stored as 'max_count' on every class

    Returns: dict, class name (from leaf_to_name) => {'taxa': [{'level': ...,
        'name': ...}, ...]} plus the optional 'max_count', with the classes
        ordered by name
    """
    spec: Dict[str, Dict[str, Any]] = {}
    for leaf in leaves:
        class_name = leaf_to_name(leaf)
        assert class_name not in spec  # class names must be unique

        entry: Dict[str, Any] = {
            'taxa': [{'level': n.level, 'name': n.name} for n in leaf]
        }
        if max_count is not None:
            entry['max_count'] = max_count
        spec[class_name] = entry

    # Order the classes alphabetically by name.
    return dict(sorted(spec.items(), key=lambda item: item[0]))

In [None]:
# Aim for ~1.4M crops in total (roughly the size of ImageNet), split evenly
# across the leaves. Assuming one good crop per three images, each class is
# capped at three times its crop budget in images. Many classes have fewer
# images than that cap, so the final dataset will come out smaller.
crops_per_class = 1.4e6 / num_leaves
label_spec = leaves_to_label_spec(leaves_to_count.keys(),
                                  max_count=int(3 * crops_per_class))

with open(label_spec_output_file, 'w') as spec_file:
    json.dump(label_spec, spec_file, indent=1)

## Plot classes graph

In [None]:
def draw_leaves_graph(tree: nx.DiGraph,
                      leaves_to_count: Mapping[Tuple[TaxonNode, ...], int]
                     ) -> gv.Digraph:
    """Draw the leaves and all of their ancestors as a graphviz Digraph.

    Each leaf becomes a blue cluster labeled with its name and image count,
    and its member nodes are filled. Ancestor nodes are drawn plainly, and
    every tree edge among the drawn nodes is included.

    See http://www.graphviz.org/doc/info/attrs.html for a description of the
    graphviz attributes used here.
    """
    def node_id(n):
        # identifier used for the node in the graphviz graph
        return f'{n.level}\n{n.name}'

    # Keep only the nodes that belong to a leaf plus everything above them.
    subtree_nodes = set()
    for leaf in leaves_to_count:
        for member in leaf:
            subtree_nodes.add(member)
            subtree_nodes.update(nx.ancestors(tree, member))
    subtree = tree.subgraph(subtree_nodes)

    gv_g = gv.Digraph(
        graph_attr={'overlap': 'false', 'concentrate': 'true', 'ranksep': '2'},
        node_attr={'margin': '0', 'width': '0', 'height': '0'})

    # One cluster per leaf. Leaf members are removed from subtree_nodes as they
    # are drawn, so only the plain ancestor nodes remain afterwards.
    for idx, (leaf, count) in enumerate(leaves_to_count.items()):
        cluster_attr = {'color': 'blue',
                        'label': f'{leaf_to_name(leaf)}\ncount: {count}'}
        with gv_g.subgraph(name=f'cluster_{idx}', graph_attr=cluster_attr,
                           node_attr={'style': 'filled'}) as cluster:
            for member in leaf:
                cluster.node(node_id(member))
                subtree_nodes.remove(member)

    # Ancestors that are not themselves part of a leaf.
    for n in subtree_nodes:
        gv_g.node(node_id(n))

    for parent, child in subtree.edges:
        gv_g.edge(node_id(parent), node_id(child))

    return gv_g

In [None]:
# Build the leaves graph, render it to SVG and PDF files named
# 'megaclassifier_groups', and show it inline.
leaves_graph = draw_leaves_graph(tree, leaves_to_count=leaves_to_count)
for output_format in ('svg', 'pdf'):
    leaves_graph.render('megaclassifier_groups', format=output_format, cleanup=True)
display(leaves_graph)