In [35]:
from robustness.tools.breeds_helpers import setup_breeds
from robustness.tools.breeds_helpers import ClassHierarchy
from robustness.tools.breeds_helpers import BreedsDatasetGenerator
from robustness.tools.breeds_helpers import print_dataset_info
import os
import random
import torch

In [36]:
# Filesystem locations: the ImageNet-1K data directory and the directory
# holding the class-hierarchy information used by the BREEDS helpers.
data_dir = '/mnt/nvme0n1p2/data/ImageNet-1K'
info_dir = '../datasets/breeds_info'

# Presumably data-loader settings; neither is read by the cells visible here.
num_workers = 1
batch_size = 5

# Fetch the hierarchy files when the info directory is missing or empty.
if not os.path.exists(info_dir) or not os.listdir(info_dir):
    print("Downloading class hierarchy information...")
    setup_breeds(info_dir)

# Hierarchy accessor and dataset generator, both built from the same info directory.
hier = ClassHierarchy(info_dir)
DG = BreedsDatasetGenerator(info_dir)

In [37]:
# Existing dataset definitions; new entries are added to this object when `save` is on.
f = './files/dataset_info.pt'
dataset_info = torch.load(f)

# Where to look for roots.
# Hierarchy level whose nodes are tried as roots (unless `use_root` is set).
root_level = 3
# Whether to use root_list
use_root = False
# Specify roots to use
root_list = ['dummy63', 'n03051540']

# What to extract below each root.
# Level to pick superclasses from
level = 5
# Number of superclasses
n_superclasses = 10
# Number of subclasses per superclass
sub_per_super = 2

# Whether to save new dataset to dataset info file
save = False

In [38]:
def delete_key(k, f):
    """
    Remove the entry stored under key ``k`` from the dataset file ``f``.

    The file is read with ``torch.load``, the entry is dropped, and the
    remaining contents are written back to the same path with ``torch.save``.
    A missing key raises ``KeyError`` before anything is written.
    """
    stored = torch.load(f)
    stored.pop(k)
    torch.save(stored, f)
    
def rename_key(k, k1, f):
    """
    Rename the entry stored under key ``k`` to ``k1`` in the dataset file ``f``.

    The file is read with ``torch.load``, the entry is moved to its new key,
    and the result is written back to the same path with ``torch.save``.

    - A missing ``k`` raises ``KeyError`` before anything is written.
    - An existing entry under ``k1`` is overwritten.
    - Renaming a key to itself keeps the entry.
    """
    dataset_info = torch.load(f)
    # Pop first, then assign: assigning and then deleting `k` would drop the
    # entry altogether when k == k1.
    dataset_info[k1] = dataset_info.pop(k)
    torch.save(dataset_info, f)

In [39]:
# For every candidate root, ask the generator for superclasses below it.
# Roots that yield at least `n_superclasses` superclasses are either reported
# (save=False) or turned into a reduced dataset entry in `dataset_info`
# (save=True).
for i in range(root_level, root_level+1):
    # Candidate roots: the explicit list, or every node at hierarchy level i.
    l = root_list if use_root else hier.get_nodes_at_level(i)
    for n, k in enumerate(sorted(l)):
        superclasses, subclass_split, label_map = DG.get_superclasses(
            level=level,
            Nsubclasses=sub_per_super,
            split=None,
            ancestor=k,
            balanced=True)  # fix subclass/superclass (adjust for long tail)

        # Skip roots that cannot supply enough superclasses.
        if len(subclass_split[0]) < n_superclasses:
            continue

        if not save:
            # Dry run: only report the (root level, root) pairs that qualify.
            print(i, k)
            continue

        # Create reduced splits: pick `n_superclasses` positions at random and
        # keep them in ascending order.
        # NOTE(review): no random seed is set in the visible cells, so the
        # selection differs between runs -- confirm whether that is intended.
        r_idx = sorted(random.sample(range(len(subclass_split[0])), n_superclasses))
        subclass_red = [subclass_split[0][idx] for idx in r_idx]
        superclasses_red = [superclasses[idx] for idx in r_idx]

        # Adjust label map: renumber the kept labels 0..len(r_idx)-1, following
        # the ascending order of their original indices.
        label_map_red = {new_label: label_map[old_label]
                         for new_label, old_label in enumerate(r_idx)}

        # Store cstm{id}_r{root level}_l{superclass level}
        dataset_info[f'cstm{n}_r{i}_l{level}'] = {
            'n_superclasses': len(r_idx),
            'n_subclasses': len(r_idx) * sub_per_super,
            'superclasses': superclasses_red,
            'label_map': label_map_red,
            'subclass_split': (subclass_red, []),
            'root': k,
        }

3 n01861778
3 n02671780
3 n03574816
3 n04341686


In [41]:
# Write the in-memory dataset definitions back to disk, but only when saving
# has been switched on.
if save:
    torch.save(obj=dataset_info, f='./files/dataset_info.pt')

dict_keys(['ds7_r2_l5', 'ds0_r0_l5', 'ds1_r0_l5', 'ds2_r0_l5', 'ds3_r0_l5', 'ds4_r0_l5', 'ds5_r2_l5', 'ds6_r2_l5', 'ds8_r3_l5', 'ds9_r3_l5', 'ds10_r3_l5', 'ds11_r3_l5'])