# In [None]:
from graphviz import Digraph

# Shared Graphviz layout and typography settings used by every diagram.
rankdir = 'TB'        # lay the graph out top-to-bottom
nodesep = '0.1'       # gap between neighbouring nodes on one rank
ranksep = '0.6'       # gap between consecutive ranks
fontname = 'Helvetica'
fontsize = '10'       # node label size
edge_fontsize = '9'   # edge label size

# Fill / stroke colours, one per data partition.
color_all_data = 'khaki'
color_institution = '#AED6F1'
color_training = 'darkseagreen'
color_preholdout = 'lightsalmon'
color_holdout = 'thistle'


def _rounded_box_style(fill):
    """Return node attributes for a filled, rounded rectangle in ``fill``."""
    return {
        'shape': 'rect',
        'style': 'filled,rounded',
        'fillcolor': fill,
        'fontname': fontname,
        'fontsize': fontsize,
    }


def _edge_style(color, **extra):
    """Return edge attributes in ``color``; ``extra`` attributes go last."""
    return {'fontsize': edge_fontsize, 'fontname': fontname, 'color': color, **extra}


# Node styles: one rounded box per partition, plus an invisible spacer point.
node_style_all_data = _rounded_box_style(color_all_data)
node_style_institution = _rounded_box_style(color_institution)
node_style_training = _rounded_box_style(color_training)
node_style_preholdout = _rounded_box_style(color_preholdout)
node_style_holdout = _rounded_box_style(color_holdout)
node_style_spacer = {'shape': 'point', 'width': '2', 'height': '0', 'style': 'invis'}

# Edge styles. Pre-holdout edges use 'salmon' rather than the node fill colour,
# and hold-out edges are dashed.
edge_style_all_data = _edge_style(color_all_data)
edge_style_training = _edge_style(color_training)
edge_style_preholdout = _edge_style('salmon')
edge_style_holdout = _edge_style(color_holdout, style='dashed')
edge_style_institution_yellow = _edge_style(color_all_data)
edge_style_institution_green = _edge_style('mediumseagreen')

# Display label for each institution ID. The ID order here is the order the
# institution nodes are declared in every diagram.
institution_labels = {
    'UC': 'UC',
    'UM': 'U-M',
    'OSU': 'OSU',
    'MUSC': 'MUSC',
    'MD': 'MD Anderson',
    'UL': 'UofL',
}
institutions = list(institution_labels)

def create_prehodout_diagram(output_path='F4.PREHOLDOUT'):
    """Render the pre-holdout data-split diagram to a PDF.

    Layout, top to bottom:
    - an "All Data" node;
    - one node per ID in the module-level ``institutions`` list, on one rank;
    - a row holding "Training Set" and "Pre-Holdout Set", separated by an
      invisible spacer.

    Every institution has an edge to both sets.

    Args:
        output_path: Base filename passed to ``Digraph.render``. The DOT source
            is kept at this path and the PDF is written with a ``.pdf`` suffix.
            Defaults to ``'F4.PREHOLDOUT'``, the previously hard-coded path.

    Returns:
        The path of the rendered PDF, as returned by ``Digraph.render``.
    """
    dot = Digraph(comment='Sample Processing Diagram')
    dot.attr(rankdir=rankdir, nodesep=nodesep, ranksep=ranksep)

    # Top node. Its only edges point down to the institution row, so it is laid
    # out above that row. A one-node rank='same' group would add nothing.
    dot.node('ALL', 'All Data', **node_style_all_data)

    # Institution nodes.
    for inst in institutions:
        dot.node(inst, institution_labels[inst], **node_style_institution)

    # Force all institution nodes onto a single rank below ALL.
    with dot.subgraph() as s:
        s.attr(rank='same')
        for inst in institutions:
            s.node(inst)

    # Connect All Data to institutions.
    for inst in institutions:
        dot.edge('ALL', inst, **edge_style_all_data)

    # Bottom row: Training Set | invisible spacer | Pre-Holdout Set.
    dot.node('CV', 'Training Set\n(Batch Factors Learned)', **node_style_training)
    dot.node('Spacer', '', **node_style_spacer)
    dot.node('Pre', 'Pre-Holdout Set\n(Batch Factors Applied,\nLabel Blind,\nRandomly Selected)', **node_style_preholdout)

    with dot.subgraph() as s:
        s.attr(rank='same')
        s.node('CV')
        s.node('Spacer')
        s.node('Pre')

    # Invisible CV -> Spacer -> Pre edges are the usual dot trick for ordering
    # nodes within a rank. Here they keep the wide spacer between the two boxes.
    dot.edge('CV', 'Spacer', style='invis')
    dot.edge('Spacer', 'Pre', style='invis')

    # Every institution gets an edge to both the training and pre-holdout sets.
    for inst in institutions:
        dot.edge(inst, 'CV', **edge_style_training)
        dot.edge(inst, 'Pre', **edge_style_preholdout)

    # cleanup=False keeps the intermediate DOT source next to the PDF. Report
    # the path render() actually produced rather than rebuilding it by hand.
    rendered_path = dot.render(output_path, format='pdf', cleanup=False)
    print(f'Graph rendered to {rendered_path}')
    return rendered_path


# Correctly spelled alias. The misspelled name above stays for existing callers.
create_preholdout_diagram = create_prehodout_diagram

def create_institute_holdout_diagram(output_path='F4.INSTITUTEHOLDOUT'):
    """Render the single-institute hold-out diagram to a PDF.

    Layout, top to bottom:
    - an "Analysis Set" node;
    - one node per ID in the module-level ``institutions`` list, on one rank;
    - a row holding "Training Set" and "Single Institute Hold-Out".

    Every institution has a solid edge to the training set and a dashed edge
    (from ``edge_style_holdout``) to the hold-out node.

    Args:
        output_path: Base filename passed to ``Digraph.render``. The DOT source
            is kept at this path and the PDF is written with a ``.pdf`` suffix.
            Defaults to ``'F4.INSTITUTEHOLDOUT'``, the previously hard-coded path.

    Returns:
        The path of the rendered PDF, as returned by ``Digraph.render``.
    """
    dot = Digraph(comment='Institute Holdout Diagram')
    dot.attr(rankdir=rankdir, nodesep=nodesep, ranksep=ranksep)

    # Top node: Analysis Set.
    dot.node('ALL', 'Analysis Set', **node_style_all_data)

    # Institution nodes, each fed from the top node.
    for inst in institutions:
        dot.node(inst, institution_labels[inst], **node_style_institution)
        dot.edge('ALL', inst, **edge_style_institution_yellow)

    # Downstream analysis nodes.
    dot.node('CV', 'Training Set', **node_style_training)
    dot.node('Combo', 'Single Institute Hold-Out', **node_style_holdout)

    # Every institution gets an edge to both analysis nodes.
    for inst in institutions:
        dot.edge(inst, 'CV', **edge_style_institution_green)
        dot.edge(inst, 'Combo', **edge_style_holdout)

    # Put all institutions on one rank.
    with dot.subgraph() as s:
        s.attr(rank='same')
        for inst in institutions:
            s.node(inst)

    # Put both analysis nodes on one rank.
    with dot.subgraph() as s:
        s.attr(rank='same')
        s.node('CV')
        s.node('Combo')

    # cleanup=False keeps the intermediate DOT source next to the PDF. Report
    # the path render() actually produced rather than rebuilding it by hand.
    rendered_path = dot.render(output_path, format='pdf', cleanup=False)
    print(f'Graph rendered to {rendered_path}')
    return rendered_path

def create_crossvalidation_diagram(output_path='F4.CROSSVALIDATION'):
    """Render the 10-fold cross-validation diagram to a PDF.

    Layout, top to bottom:
    - an "Analysis Set" node;
    - one node per ID in the module-level ``institutions`` list, on one rank;
    - a single "10-Fold Cross Validation" node that every institution feeds.

    Args:
        output_path: Base filename passed to ``Digraph.render``. The DOT source
            is kept at this path and the PDF is written with a ``.pdf`` suffix.
            Defaults to ``'F4.CROSSVALIDATION'``, the previously hard-coded path.

    Returns:
        The path of the rendered PDF, as returned by ``Digraph.render``.
    """
    dot = Digraph(comment='Cross-Validation Diagram')
    dot.attr(rankdir=rankdir, nodesep=nodesep, ranksep=ranksep)

    # Top node: Analysis Set.
    dot.node('ALL', 'Analysis Set', **node_style_all_data)

    # Institution nodes, each fed from the top node.
    for inst in institutions:
        dot.node(inst, institution_labels[inst], **node_style_institution)
        dot.edge('ALL', inst, **edge_style_all_data)

    # Single downstream node, fed by every institution.
    dot.node('CV10', '10-Fold Cross Validation', **node_style_training)
    for inst in institutions:
        dot.edge(inst, 'CV10', **edge_style_training)

    # Put all institutions on one rank.
    with dot.subgraph() as s:
        s.attr(rank='same')
        for inst in institutions:
            s.node(inst)

    # cleanup=False keeps the intermediate DOT source next to the PDF. Report
    # the path render() actually produced rather than rebuilding it by hand.
    rendered_path = dot.render(output_path, format='pdf', cleanup=False)
    print(f'Graph rendered to {rendered_path}')
    return rendered_path

# Build every figure panel in turn, in the same order as before.
for _build_diagram in (
    create_prehodout_diagram,
    create_institute_holdout_diagram,
    create_crossvalidation_diagram,
):
    _build_diagram()