In [1]:
import csv

import networkx as nx
import numpy as np
import torch

In [2]:
# Load the raw hierarchy file (winter21, going by the path).
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code; only point this at a file obtained from a trusted source.
tree = torch.load('resources/hierarchy_raw/imagenet21k/winter21/imagenet21k_miil_tree.pth')
# Alternative input (fall11, going by the path):
# tree = torch.load('resources/hierarchy_raw/imagenet21k/fall11/imagenet21k_miil_tree.pth')

In [3]:
# Class identifiers in the order stored under 'class_list', converted to a
# plain Python list (presumably from a numpy array -- TODO confirm).
order = tree['class_list'].tolist()

In [4]:
# Directed graph that receives one parent -> child edge per hierarchy link.
g = nx.DiGraph()
# Nodes whose parent edge has already been added by add_node_and_ancestors.
visited = set()

def add_node_and_ancestors(node):
    """Insert ``node`` and its chain of ancestors into the global graph ``g``.

    Follows ``tree['child_2_parent']`` upwards from ``node``, adding one
    parent -> child edge per step.  The walk stops at the first node that has
    no entry in the mapping or that is already recorded in ``visited``, so a
    ``node`` without a parent entry adds nothing to the graph.
    """
    child_to_parent = tree['child_2_parent']
    current = node
    while current in child_to_parent and current not in visited:
        visited.add(current)
        parent = child_to_parent[current]
        g.add_edge(parent, current)
        # Continue the walk from the parent we just linked.
        current = parent

# Walk upwards from every listed class so its ancestors land in the graph too.
for node in order:
    add_node_and_ancestors(node)

# Display (node count, edge count).
g.number_of_nodes(), g.number_of_edges()

(11925, 11924)

In [5]:
# List the roots (nodes without a parent); exactly one is expected.
[node for node, degree in g.in_degree() if degree == 0]

['n00001740']

In [6]:
# Count the leaves (nodes without children).
sum(degree == 0 for _, degree in g.out_degree())

8152

In [7]:
# List nodes with more than one parent; an empty list means the graph is a tree.
[node for node, degree in g.in_degree() if degree > 1]

[]

In [8]:
# List the classes from the class list that did not make it into the graph.
[cls for cls in order if not g.has_node(cls)]

['n09450163']

In [9]:
# Classes that are still missing from the graph are hung directly under the
# root.  The single-element unpacking fails loudly unless there is exactly
# one root.
# https://github.com/Alibaba-MIIL/ImageNet21K/issues/54
root, = [node for node, degree in g.in_degree() if degree == 0]
for x in order:
    if not g.has_node(x):
        g.add_edge(root, x)

In [10]:
# TODO: Avoid duplication with inat.

def dfs_edges_with_order(g, order):
    """Return the (parent, child) edges reached by walking up from ``order``.

    For each node in ``order`` (taken in sequence) the chain of not-yet-seen
    ancestors is followed upwards, and the edges of that chain are emitted
    top-down, so a node's incoming edge always comes after its parent's.
    Nodes that are never reached from ``order`` contribute no edge.

    Args:
        g: Graph-like object supporting ``g.in_degree[node]`` and
            ``g.predecessors(node)``.
        order: Iterable of nodes that determines the traversal order.

    Returns:
        List of ``(parent, node)`` tuples.

    Raises:
        ValueError: If a reached node has more than one predecessor.
    """
    seen = set()
    result = []

    for start in order:
        # Edges collected bottom-up for this walk; reversed before emitting.
        chain = []
        current = start
        while current not in seen:
            seen.add(current)
            if not g.in_degree[current]:
                # Reached a root: it has no incoming edge to emit.
                break
            parents = list(g.predecessors(current))
            if len(parents) > 1:
                raise ValueError('multiple parents', current, parents)
            parent, = parents
            chain.append((parent, current))
            current = parent
        result.extend(reversed(chain))

    return result

In [11]:
# Collect (parent, child) edges by walking up from each class in `order`.
edges = dfs_edges_with_order(g, order)

In [12]:
# Write one "parent,child" row per edge, preserving the order of `edges`.
# newline='' is what the csv module requires of files handed to csv.writer;
# without it, platforms that translate '\n' on write produce '\r\r\n' row
# endings.
with open('resources/hierarchy/imagenet21k_winter21.csv', 'w', newline='') as f:
    w = csv.writer(f)
    w.writerows(edges)