In [1]:
from random import randint

In [2]:
from airflow import DAG
from airflow.operators.python import PythonOperator, BranchPythonOperator
from airflow.operators.bash import BashOperator

In [3]:
from datetime import datetime

In [4]:
def _choose_best_model(ti):
    accuracies = ti.xcom_pull(task_ids=[
        "training_model_A",
        "training_model_B",
        "training_model_C"
    ])
    best_accuracy = max(accuracies)
    if (best_accuracy > 8):
        return "accurate"
    return "inaccurate"

In [5]:
def _training_model():
    return randint(1, 10)

In [6]:
# Daily DAG: train three simulated models in parallel, branch on the best
# accuracy, then run only the matching "accurate" / "inaccurate" task.
with DAG(
    "my_dag",
    start_date=datetime(2021, 1, 1),
    schedule_interval="@daily",
    catchup=False,  # don't backfill the runs between start_date and now
) as dag:
    # NOTE(review): variables are named "module" while the task_ids say
    # "model"; the names are kept so references elsewhere keep working.
    training_module_A = PythonOperator(
        task_id="training_model_A",
        python_callable=_training_model,
    )
    training_module_B = PythonOperator(
        task_id="training_model_B",
        python_callable=_training_model,
    )
    training_module_C = PythonOperator(
        task_id="training_model_C",
        python_callable=_training_model,
    )
    # The callable returns the task_id of the branch to follow; Airflow skips
    # the other task directly downstream of this operator.
    choose_base_model = BranchPythonOperator(
        task_id="choose_best_model",
        python_callable=_choose_best_model,
    )
    accurate = BashOperator(
        task_id="accurate",
        bash_command="echo 'accurate'",
    )
    inaccurate = BashOperator(
        task_id="inaccurate",
        bash_command="echo 'inaccurate'",
    )

    # Wire the pipeline. Every training task must finish before the branch
    # reads their XComs. Both bash tasks must sit directly downstream of the
    # branch, so that the one not chosen is skipped instead of always running.
    [training_module_A, training_module_B, training_module_C] >> choose_base_model >> [accurate, inaccurate]