In [2]:
from graphviz import Digraph

def create_ml_pipeline_diagram() -> "Digraph":
    """Build the ML pipeline flowchart as a Graphviz directed graph.

    The diagram has three dashed clusters:

    1. Model development and internal validation
    2. Final model and DFS risk stratification
    3. External validation and deployment

    Edges that stay inside one phase are declared inside that phase's
    cluster. Edges that cross from one phase to another are declared on the
    root graph (see the note at the end of the function).

    Returns:
        A ``Digraph`` with output format ``tiff`` and graph attribute
        ``dpi=600``. Nothing is rendered here; the caller invokes
        ``render()``.
    """
    # Output tiff for journal submission; set dpi to 600
    dot = Digraph(comment='ML Pipeline', format='tiff')
    dot.attr(rankdir='TB', splines='ortho', nodesep='0.8', ranksep='0.8')
    dot.attr('graph', dpi='600')

    # Default node style settings (individual nodes override shape/fill below)
    dot.attr(
        'node',
        shape='box',
        style='rounded,filled',
        fillcolor='#E3F2FD',
        fontname='Helvetica',
        fontsize='10'
    )

    # Phase 1: Model development & internal validation
    with dot.subgraph(name='cluster_dev') as c:
        c.attr(
            label='Model development and internal validation',
            style='dashed',
            color='#1565C0',
            fontcolor='#1565C0'
        )

        c.node(
            'Deriv',
            'Derivation cohort\n'
            'Stage III colon adenocarcinoma\n'
            'n = 331',
            shape='cylinder',
            fillcolor='#FFECB3'
        )

        c.node(
            'Preproc',
            'Preprocessing\n'
            '(KNN imputation, one-hot encoding,\n'
            'log CEA, variable formatting)'
        )

        c.node(
            'Nested',
            'Nested cross-validation\n'
            '(5×3-fold XGBoost,\n'
            'class-imbalance handling)'
        )

        c.node(
            'Calib',
            'Probability calibration\n'
            '(isotonic regression)'
        )

        c.node(
            'Perf',
            'Internal performance\n'
            '(AUC-ROC, Brier score,\n'
            'calibration & decision curves)'
        )

        c.edge('Deriv', 'Preproc')
        c.edge('Preproc', 'Nested')
        c.edge('Nested', 'Calib')
        c.edge('Calib', 'Perf')

    # Phase 2: Final model & DFS risk stratification
    with dot.subgraph(name='cluster_final') as c:
        c.attr(
            label='Final model and DFS risk stratification',
            style='dashed',
            color='#2E7D32',
            fontcolor='#2E7D32'
        )

        c.node(
            'FinalModel',
            'Final calibrated model\n'
            '(4 predictors: AJCC substage,\n'
            'LNR, PNI, differentiation)',
            shape='doubleoctagon',
            fillcolor='#C8E6C9',
            penwidth='2'
        )

        c.node(
            'Cutoff',
            'Risk threshold\n'
            '(Youden index for EDR-18)'
        )

        c.node(
            'RiskGroups',
            'ML high- vs low-risk groups\n'
            '(based on EDR-18 probability)'
        )

        c.node(
            'KM',
            'DFS stratification\n'
            'Kaplan–Meier curves,\n'
            'log-rank tests (overall & IIIB)'
        )

        c.edge('FinalModel', 'Cutoff')
        c.edge('Cutoff', 'RiskGroups')
        c.edge('RiskGroups', 'KM')

    # Phase 3: External validation & deployment
    with dot.subgraph(name='cluster_ext') as c:
        c.attr(
            label='External validation and deployment',
            style='dashed',
            color='#C62828',
            fontcolor='#C62828'
        )

        c.node(
            'ExtCohort',
            'External cohort\n'
            'Stage III colon adenocarcinoma\n'
            'n = 144',
            shape='cylinder',
            fillcolor='#FFECB3'
        )

        c.node(
            'ExtPrep',
            'Same preprocessing\n'
            '(KNN imputation, encoding)'
        )

        c.node(
            'ExtEval',
            'External performance\n'
            '(AUC-ROC, Brier score,\n'
            'calibration plot)'
        )

        c.node(
            'WebApp',
            'Web-based risk calculator\n'
            '(Streamlit, pre-trained model)',
            shape='component',
            fillcolor='#E1BEE7'
        )

        # External validation flow
        c.edge('ExtCohort', 'ExtPrep')
        c.edge('ExtPrep', 'ExtEval')

    # Cross-phase links are declared on the root graph, not inside a cluster.
    # In DOT, an edge declared inside a cluster makes both of its endpoints
    # members of that cluster, so declaring these inside 'cluster_final' /
    # 'cluster_ext' would place 'Perf' and 'FinalModel' in two clusters at
    # once.

    # Link Phase 1 to Phase 2
    dot.edge('Perf', 'FinalModel')

    # Connect Final model to external validation & web app
    dot.edge('FinalModel', 'ExtPrep')
    dot.edge('FinalModel', 'WebApp')

    return dot

# Build the diagram, then render it to disk. The TIFF format and 600 dpi
# setting come from create_ml_pipeline_diagram(); view=True opens the result
# in the default viewer and cleanup=True removes the intermediate DOT source.
ml_diag = create_ml_pipeline_diagram()
ml_diag.render(
    filename='Figure_S1_ML_Pipeline_Flowchart',
    view=True,
    cleanup=True,
)
print("ML pipeline flowchart generated: Figure_S1_ML_Pipeline_Flowchart.tiff")

ML pipeline flowchart generated: Figure_S1_ML_Pipeline_Flowchart.tiff
