# Creation of functions

Functions that convert the raw architecture output of the NNI library into a format that is easier for humans to read.

In [3]:
def split_normal_reduce(arc):
    """
    Description
    ---------------
    Partitions an architecture dictionary into its "normal" entries and its
    "reduce" entries. A key goes to the normal part if it contains the
    substring "normal", otherwise to the reduce part if it contains "reduce".
    Keys matching neither are reported with a message and dropped.
    Key order within each part follows the order of `arc`.

    Input(s)
    ---------------
    arc: dict

    Output(s)
    ---------------
    arc_normal: dict
    arc_reduce: dict
    """
    arc_normal = {}
    arc_reduce = {}
    for name, choice in arc.items():
        if "normal" in name:
            destination = arc_normal
        elif "reduce" in name:
            destination = arc_reduce
        else:
            print("Issue encountered : the following value is neither normal nor reduce", name)
            continue
        destination[name] = choice
    return arc_normal, arc_reduce

def convert_to_simple(arc):
    """
    Description
    ---------------
    Converts the architecture format into a simpler and more easily readable format.
    It also removes all the branches that are not used, based on the "switch" parameters.

    Every list-valued entry "<prefix>_n<k>_switch" lists the predecessor
    indices kept for node k. For each kept index i, the chosen operation is
    read from the entry "<prefix>_n<k>_p<i>". Operations are looked up by key
    name rather than by position, so the result does not depend on the order
    of the keys in `arc` (e.g. a dict reloaded from JSON with sorted keys).

    Input(s)
    ---------------
    arc: dict
        The entries of a single cell type, e.g. one of the outputs of
        split_normal_reduce.

    Output(s)
    ---------------
    kept_arc: dict
        {node_index: {predecessor_index: operation}}. Nodes appear in the
        order their switch entries appear in `arc`; predecessors appear in
        the order listed in the switch.

    Raises
    ---------------
    ValueError
        If a list-valued key does not follow the "<prefix>_n<k>_switch" pattern.
    KeyError
        If a switch references a predecessor whose "<prefix>_n<k>_p<i>"
        entry is absent from `arc`.
    """
    import re

    # "<prefix>_n<k>_switch" -> node_prefix "<prefix>_n<k>_", node "<k>"
    switch_pattern = re.compile(r"^(?P<node_prefix>.*_n(?P<node>\d+)_)switch$")

    kept_arc = dict()
    for key, value in arc.items():
        # Only switch entries hold lists; operation entries are plain values.
        if not isinstance(value, list):
            continue
        match = switch_pattern.match(key)
        if match is None:
            raise ValueError(
                "Unexpected list-valued key (expected '<prefix>_n<k>_switch'): %r" % (key,)
            )
        node_prefix = match.group("node_prefix")
        kept_arc[int(match.group("node"))] = {
            i: arc[node_prefix + "p" + str(i)] for i in value
        }
    return kept_arc

def split_prep(arc):
    """
    Description
    ---------------
    Convenience wrapper: splits the architecture into its normal and reduce
    parts with split_normal_reduce, then simplifies each part with
    convert_to_simple.

    Input(s)
    ---------------
    arc: dict

    Output(s)
    ---------------
    arc_norm: dict
        Simplified normal-cell architecture.
    arc_red: dict
        Simplified reduce-cell architecture.
    """
    arc_norm, arc_red = split_normal_reduce(arc)
    return convert_to_simple(arc_norm), convert_to_simple(arc_red)

## Example of usage

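`input_dict_1` is a real architecture output by the model. Below, it is converted twice: once step by step with `split_normal_reduce` and `convert_to_simple`, and once with the `split_prep` shortcut. Both routes should give the same result.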

In [4]:
input_dict_1 = {
    'normal_n2_p0': 'maxpool',
    'normal_n2_p1': 'maxpool',
    'normal_n3_p0': 'maxpool',
    'normal_n3_p1': 'maxpool',
    'normal_n3_p2': 'maxpool',
    'normal_n4_p0': 'maxpool',
    'normal_n4_p1': 'maxpool',
    'normal_n4_p2': 'maxpool',
    'normal_n4_p3': 'maxpool',
    'normal_n5_p0': 'maxpool',
    'normal_n5_p1': 'maxpool',
    'normal_n5_p2': 'maxpool',
    'normal_n5_p3': 'maxpool',
    'normal_n5_p4': 'maxpool',
    'reduce_n2_p0': 'maxpool',
    'reduce_n2_p1': 'maxpool',
    'reduce_n3_p0': 'maxpool',
    'reduce_n3_p1': 'sepconv5x5',
    'reduce_n3_p2': 'maxpool',
    'reduce_n4_p0': 'maxpool',
    'reduce_n4_p1': 'maxpool',
    'reduce_n4_p2': 'dilconv5x5',
    'reduce_n4_p3': 'maxpool',
    'reduce_n5_p0': 'maxpool',
    'reduce_n5_p1': 'sepconv5x5',
    'reduce_n5_p2': 'maxpool',
    'reduce_n5_p3': 'dilconv5x5',
    'reduce_n5_p4': 'maxpool',
    'normal_n2_switch': [1, 0],
    'normal_n3_switch': [2, 1],
    'normal_n4_switch': [3, 2],
    'normal_n5_switch': [2, 4],
    'reduce_n2_switch': [1, 0],
    'reduce_n3_switch': [2, 1],
    'reduce_n4_switch': [3, 2],
    'reduce_n5_switch': [3, 4]
}

# Step-by-step route: split into normal/reduce entries, then simplify each part.
dict_normal_1, dict_reduce_1 = split_normal_reduce(input_dict_1)
dict_normal_1 = convert_to_simple(dict_normal_1)
dict_reduce_1 = convert_to_simple(dict_reduce_1)

# Shortcut route: split_prep performs the same two steps in one call.
dict_normal_2, dict_reduce_2 = split_prep(input_dict_1)

# Both routes must agree; fail loudly instead of relying on comparing prints by eye.
assert dict_normal_1 == dict_normal_2
assert dict_reduce_1 == dict_reduce_2

print(dict_normal_1)
print(dict_normal_2)
print(dict_reduce_1)
print(dict_reduce_2)

{2: {1: 'maxpool', 0: 'maxpool'}, 3: {2: 'maxpool', 1: 'maxpool'}, 4: {3: 'maxpool', 2: 'maxpool'}, 5: {2: 'maxpool', 4: 'maxpool'}}
{2: {1: 'maxpool', 0: 'maxpool'}, 3: {2: 'maxpool', 1: 'maxpool'}, 4: {3: 'maxpool', 2: 'maxpool'}, 5: {2: 'maxpool', 4: 'maxpool'}}
{2: {1: 'maxpool', 0: 'maxpool'}, 3: {2: 'maxpool', 1: 'sepconv5x5'}, 4: {3: 'maxpool', 2: 'dilconv5x5'}, 5: {3: 'dilconv5x5', 4: 'maxpool'}}
{2: {1: 'maxpool', 0: 'maxpool'}, 3: {2: 'maxpool', 1: 'sepconv5x5'}, 4: {3: 'maxpool', 2: 'dilconv5x5'}, 5: {3: 'dilconv5x5', 4: 'maxpool'}}
