In [1]:
import csv
import inspect
import pathlib

import networkx as nx
import numpy as np
import torch

import util

In [2]:
# Root of the resources tree; every file read or written below is addressed relative to it.
RESOURCES_DIR = pathlib.Path('..') / 'resources'

In [3]:
# `torch.load` unpickles arbitrary Python objects. This file is not just tensors:
# `class_list` exposes `.tolist()` and yields string class IDs, so it is presumably
# a NumPy array (TODO confirm). The restricted `weights_only=True` loader, the
# default in newer torch releases, may reject such objects. Request full
# unpickling explicitly, but only when the installed torch has the flag; older
# releases do not accept `weights_only` at all.
# SECURITY: full unpickling can execute arbitrary code. Only load this file from a
# trusted source.
_TORCH_LOAD_KWARGS = (
    {'weights_only': False}
    if 'weights_only' in inspect.signature(torch.load).parameters
    else {}
)
tree_data = torch.load(
    RESOURCES_DIR / 'hierarchy_raw/imagenet21k/winter21/imagenet21k_miil_tree.pth',
    **_TORCH_LOAD_KWARGS,
)

In [4]:
# Inspect which top-level entries the loaded tree object provides.
tree_data.keys()

dict_keys(['class_list', 'child_2_parent', 'class_tree_list', 'class_description'])

In [5]:
# All class IDs in the order stored in the file. `.tolist()` turns the stored
# array into a plain Python list (presumably of `str`; TODO confirm dtype).
order = tree_data['class_list'].tolist()

In [6]:
# Is `order` already ascending? Compare every element with its successor.
all(a <= b for a, b in zip(order, order[1:]))

True

In [7]:
g = nx.DiGraph()
visited = set()

def add_node_and_ancestors(node):
    """Add `node` and its chain of ancestors to `g` as parent -> child edges.

    Climbs `tree_data['child_2_parent']` one level at a time. The walk stops at
    the first node that has no parent entry or that was already processed, so
    each child receives its incoming edge exactly once.
    """
    parent_of = tree_data['child_2_parent']
    current = node
    while current in parent_of and current not in visited:
        visited.add(current)
        parent = parent_of[current]
        g.add_edge(parent, current)
        current = parent

for node in order:
    add_node_and_ancestors(node)

len(g), len(g.edges)

(11925, 11924)

In [8]:
# Check unique root node. Enforce it with an assert so a changed input file
# cannot slip through silently; the list is still displayed as before.
roots = [x for x in g if g.in_degree[x] == 0]
assert len(roots) == 1, f'expected exactly one root node, found {len(roots)}: {roots[:10]}'
roots

['n00001740']

In [9]:
# Check number of leaf nodes (nodes with no outgoing edge).
# Note: Some labels in ImageNet21k are not leaf nodes!
sum(degree == 0 for _, degree in g.out_degree())

8152

In [10]:
# Ensure no nodes with multiple parents. The graph-building cell gives each child
# its single incoming edge behind a `visited` guard, so this should stay empty.
[node for node, degree in g.in_degree() if degree > 1]

[]

In [11]:
# Nodes without children, collected for fast membership tests.
leaf_set = {node for node, degree in g.out_degree() if degree == 0}
len(leaf_set)

8152

In [12]:
# Leaf classes only, keeping their relative order from `order`.
leaf_order = [class_id for class_id in order if class_id in leaf_set]

In [13]:
# Check for classes that are not present in the graph (they get no edges and are
# therefore absent from `leaf_order`).
[class_id for class_id in order if not g.has_node(class_id)]

['n09450163']

In [14]:
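# NOTE: This cell is currently disabled. As a result, the orphan class listed by
# the previous cell is not in `g`, so it is excluded from `leaf_order` and from
# both exported files. Uncomment to attach orphans directly under the root instead.
# # Add a direct edge from the root to any orphan nodes.
# # https://github.com/Alibaba-MIIL/ImageNet21K/issues/54
# root, = [x for x in g if g.in_degree[x] == 0]
# for x in order:
#     if x not in g:
#         g.add_edge(root, x)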

In [15]:
# Serialize the hierarchy into a sequence of edges using the project helper.
# `leaf_order` presumably fixes the traversal order of the leaves; TODO confirm
# against util.dfs_edges_with_order. Each item becomes one CSV row in the next cell.
edges = util.dfs_edges_with_order(g, leaf_order)

In [16]:
# Write the edge list, one edge per CSV row. The csv module requires the file to
# be opened with newline=''. Without it, the writer's '\r\n' row terminator is
# translated to '\r\r\n' on platforms that convert '\n' (e.g. Windows).
with open(RESOURCES_DIR / 'hierarchy/imagenet21k.csv', 'w', newline='') as f:
    csv.writer(f).writerows(edges)

In [17]:
# Export the leaf class IDs, one per line, in `leaf_order` order.
with open(RESOURCES_DIR / 'hierarchy/imagenet21k_subset.txt', 'w') as subset_file:
    subset_file.writelines(f'{class_id}\n' for class_id in leaf_order)