In [2]:
from graphviz import Digraph

# Build a Graphviz diagram of the MMoE (Multi-gate Mixture-of-Experts) model:
# input -> frozen backbone -> shared experts + one softmax gate per task ->
# per-task weighted sum -> per-task tower -> per-task output.
dot = Digraph(comment='MMoE Architecture')
# rankdir='TB'    : lay the graph out top-to-bottom.
# newrank='true'  : rank globally, so rank='same' also takes effect inside
#                   (nested) clusters.
# compound='true' : required for the lhead= edges below; without it Graphviz
#                   ignores lhead and the edges are not clipped at the border.
dot.attr(rankdir='TB', newrank='true', compound='true')

# One entry per prediction task: (node-id suffix, display name, output label).
# Node ids are derived from the suffix, e.g. 'g_family', 'mix_family', ...
TASKS = [
    ('family', 'Family', 'Family Logits'),
    ('order', 'Order', 'Order Logits'),
    ('habitat', 'Habitat', 'Habitat Logits'),
    ('troph', 'Troph', 'Troph Prediction'),
]

# --- 1. Input & Backbone (Frozen) ---
with dot.subgraph(name='cluster_input') as c:
    c.attr(label='Input & Frozen Backbone', style='filled', color='lightgrey')
    c.node('input', 'Input Image', shape='box')
    c.node('backbone', 'Frozen DINOv3 (ViT-L)', shape='box')
    c.node('features', 'Features (x) [B, D]', shape='box', style='dashed')
    c.edge('input', 'backbone')
    c.edge('backbone', 'features')

# --- 2. MMoE Layer (Trainable) ---
# Nested clusters and nodes must be added through `c`, not `dot`; otherwise
# they end up at the top level and 'cluster_mmoe' is left empty (its label and
# fill are then never drawn).
with dot.subgraph(name='cluster_mmoe') as c:
    c.attr(label='Trainable MMoE Layer', style='filled', color='lightyellow')

    # Sub-clusters inherit the parent's style/color, which would make their
    # borders invisible on the yellow fill, so draw them as dashed outlines.
    with c.subgraph(name='cluster_experts') as e:
        e.attr(label='Experts', rank='same', style='dashed', color='black')
        e.node('e1', 'Expert 1 (MLP)', shape='box')
        e.node('e_dots', '...', shape='plaintext')
        e.node('eN', 'Expert E (MLP)', shape='box')

    with c.subgraph(name='cluster_gates') as g:
        g.attr(label='Gates (1 per task)', rank='same',
               style='dashed', color='black')
        for key, name, _ in TASKS:
            g.node(f'g_{key}', f'Gate_{name} (Softmax)', shape='box')

    # Zero-width point nodes marking the expert outputs h_1 ... h_E.
    c.node('h1', '', shape='point', width='0')
    c.node('hN', '', shape='point', width='0')
    c.edge('e1', 'h1')
    c.edge('eN', 'hN')

# --- 3. Towers & Outputs (Trainable) ---
with dot.subgraph(name='cluster_output') as c:
    c.attr(label='Trainable Task Towers')
    for key, name, out_label in TASKS:
        # Visible node for the gate-weighted sum of expert outputs.
        c.node(f'mix_{key}', 'Weighted Sum', shape='circle')
        c.node(f't_{key}', f'Tower_{name} (MLP)', shape='box')
        c.node(f'out_{key}', out_label, shape='box', style='dashed')
        c.edge(f'mix_{key}', f't_{key}')
        c.edge(f't_{key}', f'out_{key}')

# --- Cross-cluster edges ---
# Declared on the root graph: mentioning a node inside a cluster's body would
# also make that node a member of the cluster.

# The features feed every expert and every gate. These edges are clipped at
# the cluster border (lhead), so two edges per cluster stand for "all members".
dot.edge('features', 'e1', lhead='cluster_experts')
dot.edge('features', 'eN', lhead='cluster_experts')
dot.edge('features', 'g_family', lhead='cluster_gates')
dot.edge('features', 'g_troph', lhead='cluster_gates')

# Each task mixes all expert outputs using its own gate's weights. Only the
# Family edges carry weight labels, to keep the figure readable.
for key, _, _ in TASKS:
    is_family = key == 'family'
    dot.edge('h1', f'mix_{key}', label='w_fam_1' if is_family else None,
             style='dotted', arrowhead='none')
    dot.edge('hN', f'mix_{key}', label='w_fam_E' if is_family else None,
             style='dotted', arrowhead='none')
    dot.edge(f'g_{key}', f'mix_{key}', style='dashed')

# Render the graph.
# Writes the DOT source to 'mmoe_architecture' and the rendered image to
# 'mmoe_architecture.png'; view=False, so no viewer is opened.
dot.render('mmoe_architecture', view=False, format='png')

print("Generated mmoe_architecture.png")

Generated mmoe_architecture.png
