# Code to automatically visualize a Genotype
At the moment, this only works for DAGs restricted to 2 incoming edges per node.

In [13]:
import graphviz
from collections import namedtuple

## Definition of the genotype to visualize

In [4]:
# A genotype describes two cells (normal and reduce). Each cell is a flat list
# of (operation name, source node index) pairs, plus the nodes whose outputs
# are connected to the cell's output node (the *_concat fields).
Genotype = namedtuple(
    "Genotype",
    ["normal", "normal_concat", "reduce", "reduce_concat"],
)

# With 2 incoming edges per node, every two consecutive pairs describe the
# incoming edges of one node, starting at node 2 (nodes 0 and 1 are inputs).
genotype = Genotype(
    normal=[
        ("sep_conv_3x3", 1),
        ("sep_conv_5x5", 0),
        ("sep_conv_3x3", 2),
        ("sep_conv_5x5", 1),
        ("sep_conv_3x3", 1),
        ("sep_conv_3x3", 2),
        ("skip_connect", 3),
        ("dil_conv_3x3", 0),
    ],
    normal_concat=range(2, 6),
    reduce=[
        ("max_pool_3x3", 1),
        ("sep_conv_3x3", 0),
        ("max_pool_3x3", 2),
        ("sep_conv_5x5", 1),
        ("sep_conv_5x5", 1),
        ("max_pool_3x3", 3),
        ("sep_conv_5x5", 1),
        ("sep_conv_5x5", 3),
    ],
    reduce_concat=range(2, 6),
)

## Definition of some constants

In [6]:
# Number of incoming operation edges per intermediate node. The code below uses
# it to decide when the next node starts while walking a cell's edge list, and
# to derive the number of nodes from the length of that list.
incoming_edges_per_node = 2

## The code

In [49]:
def __add_nodes(dag, number_nodes):
    """Create nodes ``0`` .. ``number_nodes - 1`` in the given DAG.

    Every node is named and labelled with its index. Nodes 0 and 1 are
    additionally given a fixed ``pos`` attribute (``"0,<index>!"``) together
    with ``pin="true"``.

    Args:
        dag (Digraph): The DAG the nodes are added to.
        number_nodes (int): How many nodes to add.
    """
    pinned_nodes = (0, 1)
    for index in range(number_nodes):
        name = str(index)
        if index not in pinned_nodes:
            dag.node(name, name)
            continue
        # The first two nodes are placed explicitly, one above the other.
        dag.node(name, name, pos=f"0,{index}!", pin="true")
        
        
def __add_edges(dag, edges, concat, output_node):
    """Draw a cell's operation edges and the edges into its output node.

    The entries of ``edges`` are consumed in order: the first
    ``incoming_edges_per_node`` entries end at node 2, the next group at
    node 3, and so on. Each edge is labelled with its operation name.

    Args:
        dag (Digraph): The DAG the edges are added to.
        edges (list of tuple of (str, int)): Edges to add, each given as
            (operation name, index of the source node).
        concat (list of int): Nodes whose output gets combined to form the
            cell output; each one is connected to the output node.
        output_node (int): Index of the output node.
    """
    # Nodes 0 and 1 are input nodes, so node 2 is the first to receive edges.
    target = 2
    for count, entry in enumerate(edges, start=1):
        dag.edge(str(entry[1]), str(target), label=entry[0])
        # The current node has all of its incoming edges; move to the next one.
        if count % incoming_edges_per_node == 0:
            target += 1

    # Connect every concatenated node to the output node.
    for source in concat:
        dag.edge(str(source), str(output_node))
        

def visualize_genotype(genotype):
    """Build one DAG for the normal cell and one for the reduce cell.

    Both graphs are created with the ``neato`` engine, ``png`` format and
    circular nodes.

    Args:
        genotype (namedtuple): Contains the genotype information (fields
            ``normal``, ``normal_concat``, ``reduce``, ``reduce_concat``).
            See gaea code.

    Returns:
        tuple of Digraph: Tuple containing the normal cell and reduction cell
        digraphs, in that order.
    """
    cells = (
        ("Normal Cell", genotype.normal, genotype.normal_concat),
        ("Reduce Cell", genotype.reduce, genotype.reduce_concat),
    )

    graphs = []
    for title, edges, concat in cells:
        graph = graphviz.Digraph(title, engine="neato", format="png")
        graph.attr('node', shape='circle')

        # One node per group of incoming edges, plus 3: the 2 input nodes
        # and the 1 output node.
        node_count = int(len(edges) / incoming_edges_per_node + 3)
        __add_nodes(graph, node_count)

        # The output node is the last node that was created.
        __add_edges(graph, edges, concat, node_count - 1)
        graphs.append(graph)

    return tuple(graphs)

## Test

In [50]:
# Build the digraphs for the genotype defined above: the first element is the
# normal cell, the second the reduce cell.
cell_normal, cell_reduce = visualize_genotype(genotype)

In [57]:
# Render the Digraph itself instead of wrapping it in graphviz.Source: Source
# takes DOT text rather than a graph object, and a new Source does not carry
# over the graph's "neato" engine, which the fixed positions of the two input
# nodes rely on. The format is passed explicitly so that the rendered file is
# still normal_cell.pdf.
cell_normal.render('normal_cell', format='pdf')

'normal_cell.pdf'

In [58]:
# Render the Digraph itself instead of wrapping it in graphviz.Source: Source
# takes DOT text rather than a graph object, and a new Source does not carry
# over the graph's "neato" engine, which the fixed positions of the two input
# nodes rely on. The format is passed explicitly so that the rendered file is
# still reduce_cell.pdf.
cell_reduce.render("reduce_cell", format="pdf")

'reduce_cell.pdf'