In [1]:
## Export this notebook to genotypes.py; this cell's own code is not wanted in the converted script and is meant to be removed after conversion.
## + Delete unneeded content: "# In[ ]"-style marker comments, then script lines 1 and 3-14 (NOTE(review): hard-coded, presumably sized to this cell — keep it at exactly 5 lines).
!jupyter nbconvert --to script genotypes.ipynb
!sed -i '/^#[ ]In\[/d' genotypes.py
!sed -i -e '1d;3,14d' genotypes.py

[NbConvertApp] Converting notebook genotypes.ipynb to script
[NbConvertApp] Writing 2332 bytes to genotypes.py


In [None]:
from collections import namedtuple
import torch
import torch.nn as nn
from models import ops

In [None]:
# Discrete architecture: per-node (op_name, input_idx) edges for normal and reduction cells.
# The *_concat fields presumably list which node outputs are concatenated (not used here).
Genotype = namedtuple('Genotype', ['normal', 'normal_concat', 'reduce', 'reduce_concat'])

# Candidate operations. Position in this list is the column index of an op in alpha
# (see parse()); 'none' must remain the final entry because parse() asserts it and
# excludes that last column when choosing ops.
PRIMITIVES = [
    'max_pool_3x3', 'avg_pool_3x3',
    'skip_connect',  # identity
    'sep_conv_3x3', 'sep_conv_5x5',
    'dil_conv_3x3', 'dil_conv_5x5',
    'none',
]

In [3]:
def to_dag(C_in, gene, reduction):
    """Instantiate the discrete ops described by ``gene`` as nested ModuleLists.

    Args:
        C_in: channel count passed to every op constructor in ``ops.OPS``.
        gene: sequence of per-node edge lists; each edge is an ``(op_name, s_idx)`` pair.
        reduction: when true, edges with ``s_idx < 2`` are built with stride 2
            (presumably the edges coming from the two cell inputs).

    Returns:
        ``nn.ModuleList`` holding one ``nn.ModuleList`` per node, edges in gene order.
        Each edge module carries its ``s_idx`` as an attribute; any op that is not an
        ``ops.Identity`` is wrapped as ``nn.Sequential(op, ops.DropPath_())``.
    """
    def _make_edge(op_name, s_idx):
        # Build a single edge module for (op_name, s_idx).
        edge_stride = 1
        if reduction and s_idx < 2:
            edge_stride = 2
        built = ops.OPS[op_name](C_in, edge_stride, True)
        if isinstance(built, ops.Identity):
            # Identity edges are used as-is, without drop-path.
            edge = built
        else:
            edge = nn.Sequential(built, ops.DropPath_())
        edge.s_idx = s_idx
        return edge

    nodes = nn.ModuleList()
    for node_edges in gene:
        nodes.append(nn.ModuleList([_make_edge(op_name, s_idx) for op_name, s_idx in node_edges]))
    return nodes

In [None]:
def from_str(s):
    """Evaluate ``s`` and return the resulting object.

    Presumably ``s`` is the ``str()``/``repr()`` of a ``Genotype``; it evaluates because
    ``Genotype`` and the builtin ``range`` are visible in this module's namespace.

    NOTE(review): ``eval`` executes ``s`` as arbitrary Python with this module's
    globals and builtins in scope — only pass trusted strings.
    """
    return eval(s)


def parse(alpha, k):
    """
    parse continuous alpha to discrete gene.
    alpha is ParameterList:
    ParameterList [
        Parameter(n_edges1, n_ops),
        Parameter(n_edges2, n_ops),
        ...
    ]

    gene is list:
    [
        [('node1_ops_1', node_idx), ..., ('node1_ops_k', node_idx)],
        [('node2_ops_1', node_idx), ..., ('node2_ops_k', node_idx)],
        ...
    ]
    each node has two edges (k=2) in CNN.

    Selection rule: on every incoming edge of a node, the strongest op other than the
    final 'none' column is picked; then the k edges whose picked op has the largest
    weight are kept, ordered by that weight (largest first).

    Args:
        alpha: iterable of 2-D tensors, one per node, shaped (n_edges, n_ops). Columns
            are expected to line up with PRIMITIVES; the last column ('none') is ignored.
        k: number of incoming edges to keep per node (must not exceed n_edges).

    Returns:
        list with one entry per node; each entry is a list of k ``(op_name, edge_idx)``
        tuples where ``op_name`` is a str from PRIMITIVES and ``edge_idx`` a Python int.
    """
    gene = []
    # The last column is sliced off below, which is only correct if it is 'none'.
    assert PRIMITIVES[-1] == 'none'

    for edges in alpha:
        # Best non-'none' op per edge: both results are (n_edges, 1).
        edge_max, primitive_indices = torch.topk(edges[:, :-1], 1)
        # Keep the k strongest edges; topk returns them sorted in descending order.
        _, topk_edge_indices = torch.topk(edge_max.view(-1), k)
        node_gene = []
        # Convert to plain ints so PRIMITIVES is indexed with Python ints, not tensors.
        for edge_idx in topk_edge_indices.tolist():
            prim_idx = primitive_indices[edge_idx].item()
            node_gene.append((PRIMITIVES[prim_idx], edge_idx))

        gene.append(node_gene)

    return gene