In [70]:
bench = 'nb101'
# nb101 and any 'macro' benchmark read the '_first' config; every other benchmark reads '_full'.
suffix = '_first' if bench == 'nb101' or 'macro' in bench else '_full'
cfg = '../zc_combine/configs/' + bench + suffix + '.json'
# tnb101 is evaluated on 'class_scene'; all other benchmarks on 'cifar10'.
dataset = 'class_scene' if bench == 'tnb101' else 'cifar10'

In [71]:
from zc_combine.utils.script_utils import create_cache_filename
from zc_combine.utils.script_utils import load_feature_proxy_dataset

version_key = 'paper'
# NOTE(review): the positional None/True arguments are forwarded as-is; their meaning
# is defined by create_cache_filename — confirm there.
cache_path = create_cache_filename(
    '../scripts/cache_data/', cfg, None, version_key, True
)

# The first returned value is discarded; only `data` and `y` are used below.
_, data, y = load_feature_proxy_dataset(
    '../data',
    bench,
    dataset,
    cfg=cfg,
    use_all_proxies=True,
    cache_path=cache_path,
    version_key=version_key,
)

In [106]:
from zc_combine.fixes.operations import get_ops_edges_nb201, get_ops_edges_tnb101, get_ops_nb101, get_ops_nb301

# Select the operation vocabulary of the chosen benchmark. The edge-based getters
# (nb201/tnb101) return a pair; only the first element (the op names) is kept.
if bench == 'nb201':
    ops, _ = get_ops_edges_nb201()
elif bench == 'tnb101':
    ops, _ = get_ops_edges_tnb101()
elif bench == 'nb101':
    ops = get_ops_nb101()
elif bench == 'nb301':
    ops = get_ops_nb301()
else:
    raise ValueError(f"Unsupported benchmark: {bench!r}")

print("Loaded: ", ops)

# Short display names used when building readable feature-column labels.
better_op_names = {
    'input': 'input',
    'output': 'output',
    'none': 'zero',
    'skip_connect': 'skip',
    'nor_conv_1x1': 'conv1x1',
    'nor_conv_3x3': 'conv3x3',
    'maxpool3x3': 'maxpool3x3',
    'conv1x1-bn-relu': 'conv1x1',
    'conv3x3-bn-relu': 'conv3x3'
}

# Ops without a short alias above (the nb301 branch is supported, but none of its
# op names are guaranteed to be listed) keep their raw name instead of raising KeyError.
ops = [better_op_names.get(o, o) for o in ops]
print("More readable: ", ops)

Loaded:  ['input', 'output', 'maxpool3x3', 'conv1x1-bn-relu', 'conv3x3-bn-relu']
More readable:  ['input', 'output', 'maxpool3x3', 'conv1x1', 'conv3x3']


In [107]:
# Map each op's positional id (as a string, the form in which ids are looked up
# when renaming feature columns) to its readable name.
op_map = dict((str(position), op_name) for position, op_name in enumerate(ops))

if bench == 'nb101':
    # In nb101 ids '0' and '1' are the input/output nodes, not real operations.
    op_map.pop('0')
    op_map.pop('1')

In [108]:
list(filter(lambda col: 'node' in col, data.columns))  # all node-degree feature columns

['node_degree_allowed_(2)_in_degree',
 'node_degree_allowed_(2)_out_degree',
 'node_degree_allowed_(2)_avg_in',
 'node_degree_allowed_(2)_avg_out',
 'node_degree_allowed_(2)_max_out',
 'node_degree_allowed_(2)_max_in',
 'node_degree_allowed_(3)_in_degree',
 'node_degree_allowed_(3)_out_degree',
 'node_degree_allowed_(3)_avg_in',
 'node_degree_allowed_(3)_avg_out',
 'node_degree_allowed_(3)_max_out',
 'node_degree_allowed_(3)_max_in',
 'node_degree_allowed_(4)_in_degree',
 'node_degree_allowed_(4)_out_degree',
 'node_degree_allowed_(4)_avg_in',
 'node_degree_allowed_(4)_avg_out',
 'node_degree_allowed_(4)_max_out',
 'node_degree_allowed_(4)_max_in',
 'node_degree_allowed_(2, 3)_in_degree',
 'node_degree_allowed_(2, 3)_out_degree',
 'node_degree_allowed_(2, 3)_avg_in',
 'node_degree_allowed_(2, 3)_avg_out',
 'node_degree_allowed_(2, 3)_max_out',
 'node_degree_allowed_(2, 3)_max_in',
 'node_degree_allowed_(2, 4)_in_degree',
 'node_degree_allowed_(2, 4)_out_degree',
 'node_degree_allowed_(

In [109]:
import ast

# Sanity-check how a min-path column's banned-op tuple turns into the set of ops
# that remain allowed (the complement of the banned ids within op_map).
c = 'min_path_len_banned_(0)'
if 'min_path' in c:
    # literal_eval parses '()', '(2)' and '(2, 3)' without executing arbitrary code.
    opset = ast.literal_eval(c.split('_banned_')[1])
    if isinstance(opset, int):
        opset = (opset,)
    # op_map keys are strings but opset holds ints; comparing a str key against the
    # int tuple would never match, so convert the key before the membership test.
    inverse_set = [i for i in op_map.keys() if int(i) not in opset]
    print(inverse_set)

['2', '3', '4']


In [110]:
def node_degree_bench(c, bench):
    """Return the readable label prefix for a ``node_degree_*`` feature column.

    Args:
        c: Feature column name, e.g. ``'node_degree_allowed_(2, 3)_avg_in'``.
        bench: Benchmark name; only ``'tnb101'``, ``'nb201'`` and ``'nb101'`` are handled.

    Returns:
        A prefix ending in ``' - '``, to which the caller appends the op list.

    Raises:
        ValueError: If the degree kind is not recognised or the benchmark is unsupported.
    """
    if bench in ['tnb101', 'nb201']:
        # NOTE(review): these in/out labels are mirrored relative to the nb101 branch
        # (and avg_in -> "outgoing") — kept as-is, confirm against the feature definitions.
        if 'in_degree' in c:
            return 'Input node degree - '
        elif 'out_degree' in c:
            return 'Output node degree - '
        elif 'avg_in' in c:
            return 'Average outgoing degree - '
        elif 'avg_out' in c:
            return 'Average ingoing degree - '
        else:
            raise ValueError(f"Invalid node degree: {c}")

    if bench == 'nb101':
        if 'in_degree' in c:
            return 'Output node degree - '
        if 'out_degree' in c:
            return 'Input node degree - '
        # The degree kind follows the op tuple, e.g. '..._(2, 3)_avg_in' -> 'avg_in'.
        # rsplit keeps the whole name when ')_' is absent, so it fails the check below.
        kind = c.rsplit(')_', 1)[-1]
        if kind not in ('avg_in', 'avg_out', 'max_in', 'max_out'):
            raise ValueError(f"Invalid node degree: {kind}")
        what, which = kind.split('_')
        # Trailing ' - ' keeps these labels consistent with every other branch.
        return f"{'Average' if what == 'avg' else 'Maximum'} {which}put node degree - "

    raise ValueError(f"Unsupported benchmark for node degree features: {bench}")


def get_feature_name(c):
    """Return the readable label prefix for a graph-feature column name.

    Markers are checked in order: op counts, min paths, max-op paths, then node
    degrees (the latter delegates to ``node_degree_bench`` with the notebook's
    current ``bench``). Any other column raises ``ValueError``.
    """
    simple_prefixes = (
        ('op_count', 'number of '),
        ('min_path', 'min path over '),
        ('max_op', 'max path over '),
    )
    for marker, label in simple_prefixes:
        if marker in c:
            return label

    if 'node_degree' in c:
        return node_degree_bench(c, bench)

    raise ValueError()

def to_better_colname(c, op_map):
    """Translate a raw graph-feature column name into a readable label.

    Examples (nb101 op_map): ``'op_count_3'`` -> ``'number of [conv1x1]'``,
    ``'min_path_len_banned_(2)'`` -> ``'min path over [conv1x1,conv3x3]'``.

    Args:
        c: Raw feature column name.
        op_map: Mapping from op id (as a string) to readable op name.

    Returns:
        ``f"{prefix}[op1,op2,...]"`` where the prefix comes from ``get_feature_name``.

    Raises:
        ValueError: If the column is not a recognised graph feature, or if it
            references an op id that is missing from ``op_map``.
    """
    import ast  # literal_eval: parse op tuples without eval()'ing column-name text

    feature_name = get_feature_name(c)

    if 'min_path' in c:
        # The column names the *banned* ops; the label lists the remaining ones.
        banned = ast.literal_eval(c.split('_banned_')[1])
        if isinstance(banned, int):
            banned = (banned,)
        banned = {str(o) for o in banned}
        opset = [i for i in op_map.keys() if i not in banned]
    elif 'op_count' in c:
        # Wrap the id so a multi-digit id is not iterated character by character.
        opset = [c.split('_')[-1]]
    else:
        opset = c.split('_allowed_')[1]
        if 'node' in c:
            # Strip the degree-kind suffix, e.g. '(2, 3)_avg_in' -> '(2, 3)'.
            opset = opset.split('_')[0]
        opset = ast.literal_eval(opset)

    if isinstance(opset, int):
        opset = [opset]

    try:
        op_names = [op_map[str(o)] for o in opset]
    except KeyError as err:
        # Surface as ValueError so callers that skip unrecognised columns handle it too.
        raise ValueError(f"Unknown op id {err} in column {c!r}") from err

    return f"{feature_name}[{','.join(op_names)}]"

In [111]:
new_cols_map = {}

for c in data.columns:
    # Keep raw names for op counts with no readable op: nb101's op_count_0/op_count_1
    # are the input/output nodes (dropped from op_map), and tnb101's op_count_4 refers
    # to max pooling, which is not actually present there.
    if (bench == 'nb101' and c in ['op_count_0', 'op_count_1']) or (bench == 'tnb101' and c == 'op_count_4'):
        new_cols_map[c] = c
        continue

    try:
        new_c = to_better_colname(c, op_map)
    except ValueError:
        # Not recognised as a graph feature (e.g. 'flops'): keep the column unchanged.
        print(f'Skipping {c}')
        new_cols_map[c] = c
    else:
        new_cols_map[c] = new_c

Skipping epe_nas
Skipping fisher
Skipping flops
Skipping grad_norm
Skipping grasp
Skipping jacov
Skipping l2_norm
Skipping nwot
Skipping params
Skipping plain
Skipping snip
Skipping synflow
Skipping zen
Skipping net


In [112]:
# Inspect the raw -> readable column-name mapping; skipped columns map to themselves.
new_cols_map

{'op_count_0': 'op_count_0',
 'op_count_1': 'op_count_1',
 'op_count_2': 'number of [maxpool3x3]',
 'op_count_3': 'number of [conv1x1]',
 'op_count_4': 'number of [conv3x3]',
 'min_path_len_banned_()': 'min path over [maxpool3x3,conv1x1,conv3x3]',
 'min_path_len_banned_(2)': 'min path over [conv1x1,conv3x3]',
 'min_path_len_banned_(3)': 'min path over [maxpool3x3,conv3x3]',
 'min_path_len_banned_(4)': 'min path over [maxpool3x3,conv1x1]',
 'min_path_len_banned_(2, 3)': 'min path over [conv3x3]',
 'min_path_len_banned_(2, 4)': 'min path over [conv1x1]',
 'min_path_len_banned_(3, 4)': 'min path over [maxpool3x3]',
 'max_op_on_path_allowed_(2)': 'max path over [maxpool3x3]',
 'max_op_on_path_allowed_(3)': 'max path over [conv1x1]',
 'max_op_on_path_allowed_(4)': 'max path over [conv3x3]',
 'max_op_on_path_allowed_(2, 3)': 'max path over [maxpool3x3,conv1x1]',
 'max_op_on_path_allowed_(2, 4)': 'max path over [maxpool3x3,conv3x3]',
 'max_op_on_path_allowed_(3, 4)': 'max path over [conv1x1,c