In [119]:
import numpy as np
import copy

import torch
import torch.nn as nn

In [115]:
PRIMITIVES = [
    "max_pool_3x3",
    "avg_pool_3x3",
    "skip_connect",  # identity
    "conv_1x5_5x1",
    "conv_3x3",
    "sep_conv_3x3",
    "dil_conv_3x3"
]

switches_normal = eval("[[False, True, True, False, False, True, False], [False, False, True, True, True, False, False], [True, False, True, True, False, False, False], [False, True, True, True, False, False, False], [True, True, True, False, False, False, False], [True, True, True, False, False, False, False], [True, False, False, False, True, True, False], [False, True, True, True, False, False, False], [True, True, True, False, False, False, False]]")
switches_reduce = eval("[[True, True, False, False, False, False, False], [True, False, False, True, False, False, False], [True, True, False, False, False, False, False], [False, True, False, False, True, False, False], [True, False, True, False, False, False, False], [False, True, False, False, True, False, False], [False, True, False, True, False, False, False], [False, False, False, True, True, False, False], [True, True, False, False, False, False, False]]")

alpha = [torch.tensor(np.random.normal(0, 0.4, size=(2,3))),
         torch.tensor(np.random.normal(0, 0.4, size=(3,3))), 
         torch.tensor(np.random.normal(0, 0.4, size=(4,3)))]

alpha_pairwise = [torch.tensor([1]), 
                  torch.tensor([0.3346, 0., 0.]), 
                  torch.tensor([0.0, 0.0, 0.0, 0.1686, 0.0, 0.0])]

In [123]:
def convert_tensor_alphas(alpha_concat, nodes=3):
    alphas = []
    for a_i in get_edge_indices(nodes):
        alphas.append(
            torch.Tensor(alpha_concat[a_i[0]:a_i[1]]))
        
    # print(alpha_concat, alphas)
    return alphas

def get_edge_indices(nodes=3):
    # Amount of nodes for each edge
    j = [i for i in range(2, nodes+2)] 
    
    prev = 0
    indices = []
    for i in j:
        if prev != 0:
            indices.append((sum(j[:j.index(prev)+1]), 
                            sum(j[:j.index(i)+1])))
        else:
            indices.append((0, sum(j[:j.index(i)+1])))
        prev = i
    return indices

def parse(alpha, switches, k, primitives=PRIMITIVES):
    gene = []
    j = 0
    
    for edge_i, edges in enumerate(alpha):
        # These primitive indices don't correspond to the actual
        # primitive. k=1 here.
        edge_max, primitive_indices = torch.topk(edges[:, :], 1)
        topk_edge_values, topk_edge_indices = torch.topk(edge_max.view(-1), k)
        
        # Primitive indices which are enabled
        primitives_enabled = []
        for _ in range(len(edges)):
            prim_enabled_indices = np.where(switches[j])[0]
            # The primitive operations which are enabled
            primitives_enabled.append([
                primitives[i] for i in prim_enabled_indices])
            j += 1
        
        # For each edge the highest alpha primitive indice
        node_gene = []
        for edge_idx in topk_edge_indices:
            prim_idx = primitive_indices[edge_idx]
            # print(edges)
            # print(primitives_enabled)
            # print(edge_idx, prim_idx[0])
            # print(primitives_enabled[edge_idx][prim_idx[0]])
            prim = primitives_enabled[edge_idx][prim_idx[0]]
            node_gene.append((prim, edge_idx.item()))

        gene.append(node_gene)
    return gene

parse(alpha, switches_normal, k=2)

[[('skip_connect', 1), ('sep_conv_3x3', 0)],
 [('skip_connect', 2), ('conv_1x5_5x1', 1)],
 [('avg_pool_3x3', 2), ('skip_connect', 0)]]

In [121]:
# gene = [[('sep_conv_3x3', 0), ('conv_1x5_5x1', 1)],
#  [('max_pool_3x3', 0), ('skip_connect', 2)],
#  [('skip_connect', 0), ('skip_connect', 2)]]
[['max_pool_3x3', 'skip_connect', 'conv_1x5_5x1'], 
 ['avg_pool_3x3', 'conv_1x5_5x1'], 
 ['max_pool_3x3', 'avg_pool_3x3', 'skip_connect']]

[['max_pool_3x3', 'skip_connect', 'conv_1x5_5x1'],
 ['avg_pool_3x3', 'conv_1x5_5x1'],
 ['max_pool_3x3', 'avg_pool_3x3', 'skip_connect']]

In [122]:

def limit_skip_connections(alphas, switches, num_of_sk=2, nodes=3, 
                           k=2, primitives=PRIMITIVES):
    sk_idx = primitives.index("skip_connect")
    alpha_concat = np.concatenate(alphas, axis=0)
    
    # skip-connections alpha indices
    # edge index, skip-connection alpha_index
    sk_indices = []
    # alphas corresponding to the skip-connections
    sk_alphas = []
    for i, sw in enumerate(switches):
        prim_indices = np.where(sw)[0]
        # skip-connection index of alpha
        sk_index = np.where(prim_indices==sk_idx)[0].tolist()
        sk_indices.append([i, sk_index])

        if len(sk_index) > 0:
            sk_alphas.append(alpha_concat[i][sk_index][0])
        else: # If the skip-connection is not enabled, set to infinity.
            sk_alphas.append(float("inf"))
        
    # Number of skip-connections enabled in switches
    # TODO: refactor to check based on gene
    # num_sk_enabled = sum(np.array(switches)[:, sk_idx])
    gene = parse(alphas, switches, k=k)
    num_sk_enabled = sum([1 for edge in gene 
                          for op in edge if op[0] == "skip_connect"])

    sk_a = np.array(sk_alphas)
    
    if num_sk_enabled < num_of_sk:
        alphas = convert_tensor_alphas(alpha_concat)
        gene = parse(alphas, switches, k=k)
        return gene
    else:
        it = 0
        while num_sk_enabled > num_of_sk:
            print("########## iteration", it)
            it += 1
            # Pick skip-connection index with lowest alpha 
            # value
            idx = np.argmin(sk_a)
            sk_a[idx] = float("inf")
            
            # row index and alpha index
            row_idx, alpha_idx = sk_indices[idx][0], sk_indices[idx][1][0]
            
            # Set switch sk index to False
            # switches[row_idx][sk_idx] = False
        
            # set alphas to -inf to make sure, prevent it from
            # being picked. 
            alpha_concat[row_idx][alpha_idx] = float("-inf")
            alphas = convert_tensor_alphas(alpha_concat)
            
            gene = parse(alphas, switches, k=k)
            num_sk_enabled = sum([1 for edge in gene 
                                  for op in edge if op[0] == "skip_connect"])
            
            if num_sk_enabled <= num_of_sk:
                # return the new switches
                return gene

limit_skip_connections(alpha, switches_normal)

tensor([[-0.2967, -0.2372,  0.2201],
        [ 0.2703, -0.5062,  0.0184]], dtype=torch.float64)
[['avg_pool_3x3', 'skip_connect', 'sep_conv_3x3'], ['skip_connect', 'conv_1x5_5x1', 'conv_3x3']]
tensor(1) tensor(0)
skip_connect
tensor([[-0.2967, -0.2372,  0.2201],
        [ 0.2703, -0.5062,  0.0184]], dtype=torch.float64)
[['avg_pool_3x3', 'skip_connect', 'sep_conv_3x3'], ['skip_connect', 'conv_1x5_5x1', 'conv_3x3']]
tensor(0) tensor(2)
sep_conv_3x3
tensor([[-0.4267, -0.0180, -0.4179],
        [-0.0373, -0.1254,  0.0021],
        [ 0.2339, -0.3696,  0.5404]], dtype=torch.float64)
[['max_pool_3x3', 'skip_connect', 'conv_1x5_5x1'], ['avg_pool_3x3', 'skip_connect', 'conv_1x5_5x1'], ['max_pool_3x3', 'avg_pool_3x3', 'skip_connect']]
tensor(2) tensor(2)
skip_connect
tensor([[-0.4267, -0.0180, -0.4179],
        [-0.0373, -0.1254,  0.0021],
        [ 0.2339, -0.3696,  0.5404]], dtype=torch.float64)
[['max_pool_3x3', 'skip_connect', 'conv_1x5_5x1'], ['avg_pool_3x3', 'skip_connect', 'conv_1x5_5x1'

[[('sep_conv_3x3', 0), ('conv_3x3', 1)],
 [('skip_connect', 2), ('conv_1x5_5x1', 1)],
 [('avg_pool_3x3', 2), ('skip_connect', 0)]]

In [None]:
gene = []
k = 2
j = 0

for edge_i, edges in enumerate(alpha):
    # TODO: These primitive indices don't correspond to the actual
    # primitive. k=1 here.
    edge_max, _ = torch.topk(edges[:, :], 1)
    
    topk_edge_values, topk_edge_indices = torch.topk(edge_max.view(-1), k)
    
    # Primitive indices which are enabled
    primitives_enabled = []
    for _ in range(len(edges)):
        prim_indices = np.where(switches[j])[0]
        # The primitive operations which are enabled
        primitives_enabled.append([primitives[i] for i in prim_indices])
        j += 1

    # For each edge the highest alpha primitive indice
    node_gene = []
    for edge_idx in topk_edge_indices:
        prim = primitives_enabled[edge_idx][prim_idx]
        node_gene.append((prim, edge_idx.item()))
        
    gene.append(node_gene)

gene