In [34]:
import graphviz

def generate_unet_diagram(output_file='unet_model_diagram', *, input_size=(512, 512, 3),
                          num_filters=16, depth=2):
    """Render a block diagram of a SeparableConv2D U-Net as a PNG image.

    The diagram has an input node, ``depth`` contracting stages
    (SeparableConv2D followed by MaxPooling2D), a bottleneck, ``depth``
    expansive stages (UpSampling2D, Concatenate with the matching contracting
    conv block, SeparableConv2D) and a final single-filter Conv2D output node.
    The filter count shown for stage ``i`` is ``num_filters * 2**i``; the
    bottleneck shows ``num_filters * 2**depth``.

    Args:
        output_file: Path stem handed to ``Digraph.render``; graphviz appends
            the format extension to produce the image file name.
        input_size: Shape printed in the input node's label.
        num_filters: Filter count displayed for the first (shallowest) stage.
        depth: Number of down-sampling / up-sampling stages.

    Raises:
        ValueError: If ``depth`` is negative.
    """
    if depth < 0:
        raise ValueError(f"depth must be non-negative, got {depth}")

    dot = graphviz.Digraph(format='png')
    dot.attr(rankdir='TB', size="8,8!", nodesep='1.5', ranksep='0.3')
    # Graph-level font: only affects text owned by the graph itself.
    dot.attr(fontname='Helvetica')
    # Node defaults apply only to nodes declared after this statement, so the
    # label font has to be set before the first node is added.
    dot.attr('node', fontname='Helvetica')

    # Every layer kind is a filled box with a thick border; only the colour differs.
    fill_colors = {
        'input': '#D9EAD3',
        'conv': '#ACD8E5',
        'pool': '#F7C6C7',
        'up': '#D5E8D4',
        'concat': '#FFF2CC',
        'bottleneck': '#D5A6BD',
        'output': '#C9DAF8',
    }
    styles = {
        kind: {'shape': 'box', 'style': 'filled', 'color': color, 'penwidth': '2'}
        for kind, color in fill_colors.items()
    }

    def add_layer(layer_id, label, style):
        dot.node(layer_id, label, **styles[style])

    input_id = 'input'
    add_layer(input_id, f"Input\n{tuple(input_size)}", 'input')

    # Contracting-path conv blocks, remembered so the expansive path can draw
    # the skip connections back to them.
    conv_blocks = []
    prev_layer = input_id

    # Contracting path
    for i in range(depth):
        conv_block_id = f"conv_block_{i}"
        add_layer(conv_block_id, f"SeparableConv2D\n{num_filters * (2**i)} filters", 'conv')
        dot.edge(prev_layer, conv_block_id)
        conv_blocks.append(conv_block_id)

        pool_id = f"pool_{i}"
        add_layer(pool_id, "MaxPooling2D", 'pool')
        dot.edge(conv_block_id, pool_id)

        prev_layer = pool_id

    # Bottleneck
    bottleneck_id = "bottleneck"
    add_layer(bottleneck_id,
              f"Bottleneck\nSeparableConv2D\n{num_filters * (2**depth)} filters",
              'bottleneck')
    dot.edge(prev_layer, bottleneck_id)
    prev_layer = bottleneck_id

    # Expansive path: walk the stages deepest-first so each up-sampling stage
    # is paired with the contracting block of the same index.
    for i in reversed(range(depth)):
        upsample_id = f"upsample_{i}"
        add_layer(upsample_id, "UpSampling2D", 'up')
        dot.edge(prev_layer, upsample_id)

        concat_id = f"concat_{i}"
        add_layer(concat_id, "Concatenate", 'concat')
        dot.edge(upsample_id, concat_id)
        dot.edge(conv_blocks[i], concat_id)  # skip connection

        conv_block_id = f"conv_block_up_{i}"
        add_layer(conv_block_id, f"SeparableConv2D\n{num_filters * (2**i)} filters", 'conv')
        dot.edge(concat_id, conv_block_id)

        prev_layer = conv_block_id

    # Final layer
    output_id = "output"
    add_layer(output_id, "Output\nConv2D, 1 filter", 'output')
    dot.edge(prev_layer, output_id)

    # render() returns the path of the image it actually wrote; report that
    # instead of guessing the file name.
    rendered_path = dot.render(output_file)
    print(f"Diagram saved as {rendered_path}")

# Render the default diagram only when this file is executed directly (as a
# script or in a notebook cell), not as a side effect of importing it.
if __name__ == "__main__":
    generate_unet_diagram()

Diagram saved as unet_model_diagram.png
