# Graph Construction

In [None]:
from ppcascade.fluent import from_source
from ppcascade.utils.window import Range
from ppcascade.utils.request import Request
from cascade.graph import deduplicate_nodes

# Grid requested for every field. For higher memory usage increase to
# O640 or O1280.
GRID = "O320"

# Ensemble members are numbered 1..NUM_ENSEMBLES. Can increase up to 50.
NUM_ENSEMBLES = 2

# Where the data is read from. Change to fdb if using a local FDB; FDB_HOME
# will then need to be specified. To change the grid from the one archived in
# the FDB, add the commented-out interpolation line to the request and remove
# the "grid" key.
SOURCE = "mars"

# Last forecast step requested. Can increase to a number divisible by 6,
# up to 240.
END_STEP = 60

# Date of the forecast request and date of the climatology request.
DATE = "20241015"
CLIM_DATE = "20241014"

# Forecast request that all three graphs in this notebook start from.
inputs = from_source(
    [
        {
            "class": "od",
            "expver": "0001",
            "stream": "enfo",
            "date": DATE,
            "time": "00",
            "param": 167,
            "levtype": "sfc",
            "type": "pf",
            # Members 1..NUM_ENSEMBLES inclusive.
            "number": range(1, NUM_ENSEMBLES + 1),
            # Steps 0, 3, 6, ... up to END_STEP.
            "step": range(0, END_STEP + 1, 3),
            "source": SOURCE,
            "grid": GRID,
            # "interpolate": {"grid": GRID},
        }
    ]
)

# Ensemble mean and standard deviation for each time step: the "ensms"
# operation is applied along the "number" dimension.
ensms_graph = inputs.ensemble_operation("ensms", dim="number", batch_size=2).graph()

# Graph for a threshold probability over step windows: the per-window minimum
# along "step" is compared against the threshold with "<=".
#
# Two families of windows are used:
#   * single-step windows at 0, 24, 48, ... up to END_STEP;
#   * 120-step-wide windows sampled every 6 steps. This second family is empty
#     while END_STEP is below 120.
prob_windows = [
    *(Range(f"{step}-{step}", [step]) for step in range(0, END_STEP + 1, 24)),
    *(
        Range(f"{start}-{start + 120}", [*range(start + 6, start + 121, 6)])
        for start in range(0, END_STEP - 119, 120)
    ),
]
prob_graph = (
    inputs.window_operation("min", prob_windows, dim="step", batch_size=2)
    .ensemble_operation(
        "threshold_prob",
        comparison="<=",
        local_scale_factor=2,
        # NOTE(review): 273.15 is 0 degC if the field is in kelvin - confirm units.
        value=273.15,
    )
    .graph()
)

# Graph for computing extreme forecast indices - this has the highest memory consumption
# Back-to-back 24-step windows named "0-24", "24-48", ...; each one lists the
# steps start+6, start+12, start+18 and start+24.
efi_windows = [
    Range(f"{start}-{start + 24}", [*range(start + 6, start + 25, 6)])
    for start in range(0, END_STEP - 23, 24)
]
# Climatology request. Its "step" entries ("0-24", "24-48", ...) are built
# over the same range as the EFI windows, and 101 quantiles "0:100".."100:100"
# are requested.
climatology = from_source(
    [
        Request(
            {
                "class": "od",
                "expver": "0001",
                "stream": "efhs",
                "date": CLIM_DATE,
                "time": "00",
                "param": 228004,
                "levtype": "sfc",
                "type": "cd",
                "quantile": [f"{q}:100" for q in range(101)],
                "step": [f"{start}-{start + 24}" for start in range(0, END_STEP - 23, 24)],
                "source": SOURCE,
                "grid": GRID,
                # "interpolate": {"grid": GRID},
            },
            # NOTE(review): presumably keeps the quantile list together rather
            # than expanding it into one request per value - confirm.
            no_expand=("quantile",),
        )
    ]
)
# EFI graph: mean over each EFI window along "step", followed by the "extreme"
# operation against the climatology.
efi_graph = (
    inputs.window_operation("mean", efi_windows, dim="step", batch_size=2)
    .ensemble_extreme(
        "extreme",
        climatology,
        efi_windows,
        sot=[10, 90],
        eps=1e-4,
        # NOTE(review): these look like GRIB keys applied to the outputs -
        # confirm against ensemble_extreme.
        metadata={
            "edition": 1,
            "gribTablesVersionNo": 132,
            "indicatorOfParameter": 167,
            "localDefinitionNumber": 19,
            "timeRangeIndicator": 3,
        },
    )
    .graph()
)

# Combine the three graphs into one and pass the result through
# deduplicate_nodes.
# NOTE(review): presumably this merges nodes the graphs share (all three start
# from `inputs`) - confirm against cascade.graph.
total_graph = deduplicate_nodes(ensms_graph + prob_graph + efi_graph)

In [None]:
from cascade.visualise import visualise

# Visualise the combined graph; "pproc_2t.html" is passed as the second
# argument (presumably the output file - confirm).
visualise(total_graph, "pproc_2t.html")

In [None]:
from cascade.cascade import Cascade

# Benchmark the EFI graph only (efi_graph, not total_graph).
# NOTE(review): presumably chosen because the EFI graph is flagged earlier as
# having the highest memory consumption - confirm this is intentional.
cas = Cascade(efi_graph)
# NOTE(review): "2t_memray" is presumably a label/output name for the
# benchmark run - confirm against Cascade.benchmark.
cas.benchmark("2t_memray")