In [1]:
import csv

import numpy as np

import hier

In [2]:
# Read the edge list from the CSV file and build the hierarchy and the
# per-node names from it. Both steps stay inside the `with` block in case
# `hier.load_edges` reads from the file lazily.
with open('resources/hierarchy/inat21.csv') as edges_file:
    edges = hier.load_edges(edges_file)
    tree, node_names = hier.make_hierarchy_from_edges(edges)

In [3]:
def split_level(rng, depth, num_folds, hierarchy=None):
    """Partition the nodes at `depth` into `num_folds` folds of similar leaf count.

    For every node at depth `depth - 1`, its children are shuffled and laid
    end to end on a line, each child occupying a segment whose length is its
    number of leaf descendants. The line is cut into `num_folds` equal bins
    and each child goes to the bin selected by the center of its segment.
    The bins are then assigned to folds in a random order, so that no fold
    systematically receives the rounding surplus.

    Args:
        rng: Source of randomness; only `rng.permutation(n)` is used
            (e.g. a `numpy.random.Generator`).
        depth: Depth of the nodes to partition; their parents are the nodes
            whose depth equals `depth - 1`.
        num_folds: Number of folds to produce.
        hierarchy: Object providing `children()`, `num_leaf_descendants()`
            and `depths()`. Defaults to the module-level `tree`, which is
            what this function always used before the parameter existed.

    Returns:
        A list of `num_folds` sorted integer arrays of node indices. A fold
        that received no node is an empty *integer* array, so it can still be
        used to index other arrays.
    """
    if hierarchy is None:
        hierarchy = tree

    node_to_children = hierarchy.children()
    sizes = hierarchy.num_leaf_descendants()
    parent_nodes = np.flatnonzero(hierarchy.depths() == depth - 1)

    fold_subsets = [list() for _ in range(num_folds)]
    for parent in parent_nodes:
        # NOTE(review): assumes `children()` has an entry for every node at
        # depth - 1 (i.e. none of them is a leaf) - confirm for other trees.
        children = node_to_children[parent]
        # Shuffle the child nodes.
        order = rng.permutation(len(children))
        children = children[order]
        child_sizes = sizes[children]
        # Segment boundaries of the children along the line, and their centers.
        cumsum = np.concatenate(([0], np.cumsum(child_sizes)))
        center_size = (cumsum[:-1] + cumsum[1:]) / 2
        # Rescale so that the whole line spans [0, num_folds].
        center_bin = center_size * num_folds / np.sum(child_sizes)
        nearest_bin = np.round(-0.5 + center_bin).astype(int)
        # A center lying exactly on the upper end of the line (possible only
        # for a zero-sized child placed last) would round to `num_folds` and
        # the child would silently be left out of every fold; keep it in the
        # last bin instead.
        nearest_bin = np.clip(nearest_bin, 0, num_folds - 1)
        bins = [children[nearest_bin == i] for i in range(num_folds)]
        # Permute the bins too, to avoid rounding bias.
        bins = [bins[i] for i in rng.permutation(num_folds)]
        for i in range(num_folds):
            fold_subsets[i].extend(bins[i])

    # Force an integer dtype: `np.sort([])` would give a float64 array, which
    # cannot be used as an index array by callers.
    return [np.sort(np.array(subset, dtype=int)) for subset in fold_subsets]

In [4]:
seed = 0

sizes = tree.num_leaf_descendants()


def _write_spanning_subtree(nodes, path):
    """Write the rooted subtree of `tree` spanning `nodes` to `path` as CSV.

    Each row is one edge of the subtree, as yielded by `subtree.edges()`,
    with both endpoints written as node names taken from `node_names`.
    """
    subtree, node_subset = hier.rooted_subtree_spanning(tree, nodes)
    # `node_subset` is indexed by subtree node and holds indices into
    # `node_names`, so this yields one name per subtree node.
    subtree_names = [node_names[node] for node in node_subset]
    subtree_edges = [(subtree_names[u], subtree_names[v]) for u, v in subtree.edges()]
    # newline='' is required by the csv module; without it the writer's own
    # '\r\n' terminator gets translated again on platforms such as Windows.
    with open(path, 'w', newline='') as f:
        csv.writer(f).writerows(subtree_edges)


for num_folds in [2, 3]:
    for depth in [4, 5, 6, 7]:
        # The generator is re-created from the same seed for every configuration.
        partitions = split_level(np.random.default_rng(seed), depth, num_folds)
        print(f'folds {num_folds}, depth {depth}:', [np.sum(sizes[nodes]) for nodes in partitions])

        for i in range(num_folds):
            prefix = f'resources/subtree/inat21_partition_d{depth}_n{num_folds}_i{i}'

            # Subtree spanning the nodes of fold i.
            _write_spanning_subtree(partitions[i], f'{prefix}.csv')

            # Subtree spanning every other fold. With two folds this is just
            # the other fold, but it is written anyway so that each fold
            # always has a matching `_c` file.
            complement = np.concatenate(partitions[:i] + partitions[i+1:])
            _write_spanning_subtree(complement, f'{prefix}_c.csv')

folds 2, depth 4: [4839, 5161]
folds 2, depth 5: [5062, 4938]
folds 2, depth 6: [4919, 5081]
folds 2, depth 7: [4981, 5019]
folds 3, depth 4: [2767, 3011, 4222]
folds 3, depth 5: [2919, 3212, 3869]
folds 3, depth 6: [3345, 3365, 3290]
folds 3, depth 7: [3304, 3349, 3347]
