# Capture Rate Analysis

## Imports and Constants

In [72]:
import datetime as dt
import logging
import os
import shutil
from typing import List

import plotly.express as px
import polars as pl
from nemosis import cache_compiler, defaults, dynamic_data_compiler, static_table
from tqdm import tqdm

In [73]:
start_time = "2014/01/01 00:00:00"
end_time = "2024/12/31 23:55:00"

start_5ms_date = dt.date(2021, 10, 1)
data_cache = "/home/matthew/Data/nem-capture-rate/"
nemosis_data_cache = "/home/matthew/Data/nemosis/"
aggregated_data_cache = os.path.join(data_cache, "02_aggregated")
sorted_data_cache = os.path.join(data_cache, "03_sorted_compacted")
sorted_rooftop_path = os.path.join(sorted_data_cache, "ROOFTOP_PV_ACTUAL.parquet")
results_dir = os.path.join(data_cache, "99_results")
table_names = ["TRADINGPRICE", "DISPATCHPRICE", "DISPATCHLOAD", "ROOFTOP_PV_ACTUAL"]

In [74]:
INTERVALS_PER_H = 60 // 5

In [75]:
fcas_directions = ["RAISE", "LOWER"]
fcas_speeds = ["1SEC", "6SEC", "5MIN", "REG"]

fcas_markets = [d + s for d in fcas_directions for s in fcas_speeds]

In [76]:
logging.getLogger("nemosis").setLevel(logging.WARNING)

## Download Data

In [77]:
os.makedirs(nemosis_data_cache, exist_ok=True)

In [78]:
# nemosis doesn't include these columns yet
# https://github.com/UNSW-CEEM/NEMOSIS/issues/44
defaults.table_columns["DISPATCHPRICE"] += ["LOWER1SECRRP", "RAISE1SECRRP"]
defaults.table_columns["TRADINGPRICE"] += ["LOWER1SECRRP", "RAISE1SECRRP"]

In [79]:
for table in tqdm(table_names):
    cache_compiler(
        start_time,
        end_time,
        table,
        nemosis_data_cache,
        keep_csv=True,
        fformat="parquet",
        rebuild=False,
    )

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 37.34it/s]


## Preprocess Data

We will compact data (squish many small files into one big file).

We'll also parse datetime strings into datetimes, drop `INTERVENTION` columns, deduplicate and interpolate rooftop solar, and sort each dataset individually.

The docs for polars say that you should do one big query, instead of lots of chained queries, so that the optimiser can do it's thing. However in this case that leads to running out of memory on a normal laptop. So we're going to split it up.

The sorting is to make the subsequent `group_by_dynamic` more memory efficient.

In [80]:
def parse_datetimes(
    lf, cols=["SETTLEMENTDATE"], format="%Y/%m/%d %H:%M:%S"
) -> pl.LazyFrame:
    for col in cols:
        lf = lf.with_columns(
            pl.col(col).str.strptime(pl.Datetime(time_unit="ms"), format=format)
        )
    return lf


def list_raw_files(tablename) -> List[str]:
    return [
        os.path.join(nemosis_data_cache, fname)
        for fname in os.listdir(nemosis_data_cache)
        if fname.endswith(".parquet") and (table in fname)
    ]


def drop_intervention(lf) -> pl.LazyFrame:
    if "INTERVENTION" in lf.collect_schema():
        return lf.filter(pl.col("INTERVENTION") == 0).select(pl.exclude("INTERVENTION"))
    else:
        return lf


def scan_raw_table(tablename) -> pl.LazyFrame:
    return pl.concat(
        [(pl.scan_parquet(path)) for path in list_raw_files(tablename)],
        how="diagonal_relaxed",
    )


# requires collecting
def upsample(lf, time_col, group_col="REGIONID") -> pl.DataFrame:
    return (
        lf.sort(time_col, group_col)
        .collect()
        .upsample(
            time_column=time_col,
            every=dt.timedelta(minutes=5),
            group_by=group_col,
            maintain_order=True,
        )
        .fill_null(strategy="backward", limit=30 // 5 - 1)
    )

In [81]:
os.makedirs(sorted_data_cache, exist_ok=True)

### Dispatchprice and Tradingprice

Compact into one file, because they're small.

In [82]:
table = "TRADINGPRICE"
dest_path = os.path.join(sorted_data_cache, table + ".parquet")
(
    scan_raw_table(table)
    .pipe(parse_datetimes)
    .pipe(drop_intervention)
    .pipe(lambda lf: upsample(lf, time_col="SETTLEMENTDATE"))
    .write_parquet(dest_path)
)

In [83]:
table = "DISPATCHPRICE"
dest_path = os.path.join(sorted_data_cache, table + ".parquet")
(
    scan_raw_table(table)
    .pipe(parse_datetimes)
    .pipe(drop_intervention)
    .sink_parquet(dest_path)
)

### Static Data

Fuel type and region for each generator.



In [84]:
static_pd: "pd.DataFrame" = static_table(
    "Generators and Scheduled Loads", nemosis_data_cache
)

static: pl.LazyFrame = (
    pl.from_pandas(static_pd)
    .lazy()
    # simplify fuel categories
    .with_columns(
        pl.concat_str(
            [
                "Fuel Source - Primary",
                "Technology Type - Descriptor",
                "Fuel Source - Descriptor",
            ],
            separator=" - ",
        )
        .str.to_lowercase()
        .alias("Fuel Detail"),
    )
    .filter(~pl.col("Participant").str.contains("Basslink"))
    .with_columns(
        pl.when(pl.col("Fuel Detail").str.contains("battery"))
        .then(pl.lit("battery"))
        .when(
            pl.col("Fuel Detail").str.contains("hydro")
            & (pl.col("Dispatch Type") == "Load")
        )
        .then(pl.lit("pumps"))
        .when(pl.col("Fuel Detail").str.contains("hydro"))
        .then(pl.lit("hydro"))
        .when(pl.col("Fuel Detail").str.contains("solar"))
        .then(pl.lit("solar_gridscale"))
        .when(pl.col("Fuel Detail").str.contains("wind"))
        .then(pl.lit("wind"))
        .when(pl.col("Fuel Detail").str.contains("waste coal mine gas"))
        .then(pl.lit("coal"))
        .when(
            pl.any_horizontal(
                (
                    pl.col("Fuel Detail").str.contains(s)
                    for s in [
                        "natural gas",
                        "ocgt",
                        "coal seam gas",
                        "coal seam methane",
                    ]
                )
            )
        )
        .then(pl.lit("gas"))
        .when(
            pl.any_horizontal(
                (pl.col("Fuel Detail").str.contains(c))
                # don't blindly search for the word "coal"
                # because coal seam gas should count as gas not coal
                for c in ["black coal", "brown coal"]
            )
        )
        .then(pl.lit("coal"))
        .when(pl.col("Fuel Detail").str.contains("diesel"))
        .then(pl.lit("distillate"))
        .when(pl.col("Fuel Detail").str.contains("biomass"))
        .then(pl.lit("biomass"))
        .when(pl.col("DUID").str.contains("PUMP") & (pl.col("Dispatch Type") == "Load"))
        .then(pl.lit("pumps"))
        .alias("FUEL_TYPE")
    )
    .select("DUID", pl.col("Region").alias("REGIONID"), "FUEL_TYPE")
)

# check we've categorised everything
assert not static.select(
    pl.col("FUEL_TYPE").is_null().any().alias("UNKNOWN")
).collect()["UNKNOWN"][0]



### Dispatchload

We'll join this to static data to get region and fuel type. Then compact into one row per (time interval, fuel type, region).
Since this is so large, do it one file at a time. Then compact the final result into one file.

In [85]:
table = "DISPATCHLOAD"

dest_dir = os.path.join(aggregated_data_cache, table)
shutil.rmtree(dest_dir, ignore_errors=True)
os.makedirs(dest_dir)


agged_files = []

for source_path in tqdm(list_raw_files(table)):
    assert os.path.exists(source_path), f"Source path doesn't exist: {source_path}"
    dest_path = os.path.join(dest_dir, os.path.basename(source_path))
    assert source_path != dest_path
    agged_files.append(dest_path)

    try:
        schema = pl.read_parquet_schema(source_path)
        present_fcas_cols = [c for c in fcas_markets if c in schema]
        (
            pl.scan_parquet(source_path)
            .pipe(drop_intervention)
            .join(static, on="DUID")
            .with_columns(
                pl.col("FUEL_TYPE").fill_null(value="UNKNOWN"),
                pl.mean_horizontal("INITIALMW", "TOTALCLEARED").alias("POWER"),
            )
            .group_by("SETTLEMENTDATE", "REGIONID", "FUEL_TYPE")
            .agg(
                pl.col("POWER").sum(),
                pl.col(present_fcas_cols).sum(),
            )
            .pipe(parse_datetimes)
            .sink_parquet(dest_path)
        )
    except Exception:
        print(f"Error with {source_path}")
        if os.path.exists(dest_path):
            os.remove(dest_path)
        raise

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 68/68 [00:12<00:00,  5.33it/s]


After aggregating by fuel type (and region), now we have lots of tiny files we can compact into one.

In [86]:
dest_path = os.path.join(sorted_data_cache, table + ".parquet")
(
    pl.concat(
        [pl.scan_parquet(f) for f in agged_files], how="diagonal_relaxed"
    ).sink_parquet(dest_path)
)

### Rooftop Data

In [87]:
table = "ROOFTOP_PV_ACTUAL"
rooftop_duplicated = scan_raw_table(table)

rooftop_duplicated.head().collect()

INTERVAL_DATETIME,REGIONID,POWER,QI,TYPE,LASTCHANGED
str,str,f64,f64,str,str
"""2020/02/01 00:30:00""","""NSW1""",0.0,1.0,"""MEASUREMENT""","""2020/02/01 00:50:32"""
"""2020/02/01 00:30:00""","""NSW1""",0.0,0.6,"""SATELLITE""","""2020/02/01 00:50:23"""
"""2020/02/01 00:30:00""","""QLD1""",0.0,0.6,"""SATELLITE""","""2020/02/01 00:50:23"""
"""2020/02/01 00:30:00""","""QLD1""",0.0,1.0,"""MEASUREMENT""","""2020/02/01 00:50:32"""
"""2020/02/01 00:30:00""","""QLDC""",0.0,1.0,"""MEASUREMENT""","""2020/02/01 00:50:32"""


In [88]:
# deduplicate across the different methods of estimation
rooftop_30 = (
    rooftop_duplicated.pipe(drop_intervention)
    .filter(pl.col("REGIONID").str.ends_with("1"))
    # deduplicate
    .sort(by=["TYPE", "QI", "LASTCHANGED"], descending=[False, True, True])
    .group_by(["REGIONID", "INTERVAL_DATETIME"])
    .first()
    .select(["REGIONID", "INTERVAL_DATETIME", "POWER"])
    .rename({"INTERVAL_DATETIME": "INTERVAL_END"})
)

In [89]:
# now interpolate rooftop solar from 30 minutes to 5

(
    rooftop_30
    .pipe(lambda lf: parse_datetimes(lf, cols=["INTERVAL_END"]))
    .pipe(lambda lf: upsample(lf, time_col="INTERVAL_END"))
    .write_parquet(sorted_rooftop_path)
)

## Scan Data

In [90]:
dispatchprice = pl.scan_parquet(
    os.path.join(sorted_data_cache, "DISPATCHPRICE.parquet"), low_memory=True
)
assert "INTERVENTION" not in dispatchprice.collect_schema()
dispatchprice.head().collect()

SETTLEMENTDATE,REGIONID,RRP,RAISE6SECRRP,RAISE60SECRRP,RAISE5MINRRP,RAISEREGRRP,LOWER6SECRRP,LOWER60SECRRP,LOWER5MINRRP,LOWERREGRRP,PRICE_STATUS,RAISE1SECRRP,LOWER1SECRRP
datetime[ms],str,f64,f64,f64,f64,f64,f64,f64,f64,f64,str,f64,f64
2020-03-01 00:05:00,"""NSW1""",47.91176,4.9,3.89,0.99,10.99,0.45,0.45,0.18,8.5,"""FIRM""",,
2020-03-01 00:05:00,"""QLD1""",47.73009,4.9,3.89,0.99,10.99,0.45,0.45,0.18,8.5,"""FIRM""",,
2020-03-01 00:05:00,"""SA1""",301.0,4.9,3.89,0.99,10.99,0.0,0.0,0.18,8.5,"""FIRM""",,
2020-03-01 00:05:00,"""TAS1""",50.06809,4.9,3.89,0.99,10.99,0.0,0.18,0.18,8.5,"""FIRM""",,
2020-03-01 00:05:00,"""VIC1""",52.27487,4.9,3.89,0.99,10.99,0.45,0.45,0.18,8.5,"""FIRM""",,


In [91]:
tradingprice = pl.scan_parquet(
    os.path.join(sorted_data_cache, "*TRADINGPRICE*.parquet"), low_memory=True
)
assert "INTERVENTION" not in tradingprice.collect_schema()
tradingprice.head().collect()

SETTLEMENTDATE,REGIONID,RRP,RAISE6SECRRP,RAISE60SECRRP,RAISE5MINRRP,RAISEREGRRP,LOWER6SECRRP,LOWER60SECRRP,LOWER5MINRRP,LOWERREGRRP,PRICE_STATUS,RAISE1SECRRP,LOWER1SECRRP
datetime[ms],str,f64,f64,f64,f64,f64,f64,f64,f64,f64,str,f64,f64
2013-12-01 00:30:00,"""NSW1""",48.98,0.21,0.17,0.63,0.63,0.2,0.18,0.27,0.27,"""FIRM""",,
2013-12-01 00:35:00,"""NSW1""",47.79,0.41,0.17,0.68,1.0,0.09,0.18,0.27,0.63,"""FIRM""",,
2013-12-01 00:40:00,"""NSW1""",47.79,0.41,0.17,0.68,1.0,0.09,0.18,0.27,0.63,"""FIRM""",,
2013-12-01 00:45:00,"""NSW1""",47.79,0.41,0.17,0.68,1.0,0.09,0.18,0.27,0.63,"""FIRM""",,
2013-12-01 00:50:00,"""NSW1""",47.79,0.41,0.17,0.68,1.0,0.09,0.18,0.27,0.63,"""FIRM""",,


In [92]:
actual = (
    tradingprice.select(
        (pl.col("SETTLEMENTDATE").shift(-1) - pl.col("SETTLEMENTDATE")).alias(
            "INTERVAL_LENGTH"
        )
    )
    .head(1)
    .collect()
    .item()
)
expected = dt.timedelta(minutes=5)
assert actual == expected, "tradingprice not upsampled"

In [93]:
rooftop = pl.scan_parquet(sorted_rooftop_path, low_memory=True)
assert "INTERVENTION" not in rooftop.collect_schema()
rooftop.head().collect()

INTERVAL_END,REGIONID,POWER
datetime[ms],str,f64
2016-08-01 00:30:00,"""NSW1""",0.0
2016-08-01 00:35:00,"""NSW1""",0.0
2016-08-01 00:40:00,"""NSW1""",0.0
2016-08-01 00:45:00,"""NSW1""",0.0
2016-08-01 00:50:00,"""NSW1""",0.0


In [94]:
source_path = os.path.join(sorted_data_cache, "DISPATCHLOAD.parquet")

dispatchload = pl.scan_parquet(source_path, low_memory=True)
assert "INTERVENTION" not in dispatchload.collect_schema()
dispatchload.head().collect()

SETTLEMENTDATE,REGIONID,FUEL_TYPE,POWER,RAISE6SEC,RAISE5MIN,RAISEREG,LOWER6SEC,LOWER5MIN,LOWERREG,RAISE1SEC,LOWER1SEC
datetime[ms],str,str,f64,f64,f64,f64,f64,f64,f64,f64,f64
2014-03-01 05:10:00,"""NSW1""","""coal""",3808.413265,160.0,214.13789,83.0,16.29999,54.20743,31.0,,
2014-03-01 06:50:00,"""NSW1""","""hydro""",54.9201,0.0,0.0,0.0,0.0,0.0,0.0,,
2014-03-01 08:20:00,"""QLD1""","""coal""",4539.060905,67.125,34.0,68.0,0.0,12.0,46.0,,
2014-03-01 09:05:00,"""QLD1""","""coal""",4579.92501,51.7875,33.09999,38.90001,0.0,12.0,54.0,,
2014-03-01 10:00:00,"""QLD1""","""hydro""",137.21,0.0,60.0,0.0,0.0,0.0,0.0,,


Resistor Revenue

In [95]:
neg_results = (
    dispatchprice
    .filter(pl.col("RRP") < 0)
    .sort("SETTLEMENTDATE")
    .group_by_dynamic("SETTLEMENTDATE", every="1mo", group_by="REGIONID")
    .agg(
        (-pl.col("RRP") / INTERVALS_PER_H).sum().alias("REVENUE")
    )
    .collect()
)

In [96]:

fig = px.line(
    neg_results,
    x="SETTLEMENTDATE",
    y="REVENUE",
    color="REGIONID",
    title=f"Revenue ($)"
)

#fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))

fig.show()

## Computation

In [97]:
os.makedirs(results_dir, exist_ok=True)

* dispatchload joined with static
* concat with rooftop
* join with prices
* capture rate over time
* plot

In [98]:
regionid_enum = pl.Enum(["QLD1", "NSW1", "VIC1", "TAS1", "SA1"])
fuel_type_enum = pl.Enum(
    [
        "gas",
        "coal",
        "solar_rooftop",
        "solar_gridscale",
        "wind",
        "hydro",
        "distillate",
        "pumps",
        "battery",
    ]
)

In [99]:
def calculate(metric):
    return (
        pl.concat(
            [
                dispatchload.rename({"SETTLEMENTDATE": "INTERVAL_END"}),
                (rooftop.with_columns(pl.lit("solar_rooftop").alias("FUEL_TYPE"))),
            ],
            how="diagonal_relaxed",
        )
        .with_columns((pl.col("POWER") / INTERVALS_PER_H).alias("ENERGY_MWH"))
        .cast(
            {
                "REGIONID": regionid_enum,
                "FUEL_TYPE": fuel_type_enum,
            }
        )
        .join(
            tradingprice.with_columns(
                pl.col("SETTLEMENTDATE").alias("INTERVAL_END"),
                pl.col("RRP").alias("ENERGY_PRICE"),
            ).cast(
                {
                    "REGIONID": regionid_enum,
                }
            ),
            on=["REGIONID", "INTERVAL_END"],
        )
        .with_columns(
            [
                (pl.col("POWER") * pl.col("RRP") / INTERVALS_PER_H).alias("ENERGY_REVENUE"),
                (
                    pl.sum_horizontal([pl.col(m) * pl.col(m + "RRP") for m in fcas_markets])
                    / INTERVALS_PER_H
                ).alias("FCAS_REVENUE"),
            ]
        )
        .with_columns(
            (pl.col("ENERGY_REVENUE") + pl.col("FCAS_REVENUE")).alias("TOTAL_REVENUE")
        )
        .with_columns(
            pl.col("ENERGY_MWH")
            .sum()
            .over("REGIONID", "INTERVAL_END")
            .alias("REGION_INTERVAL_MWH"),
            pl.col("TOTAL_REVENUE")
            .sum()
            .over("REGIONID", "INTERVAL_END")
            .alias("REGION_INTERVAL_REVENUE"),
        )
        .sort("INTERVAL_END")
        .group_by_dynamic(
            index_column="INTERVAL_END",
            every="1y",
            label="left",
            group_by=["FUEL_TYPE"],  # "REGIONID",
        )
        .agg(
            pl.col("ENERGY_PRICE").mean().alias("TWAP"),
            pl.col("REGION_INTERVAL_REVENUE").sum().alias("REGION_REVENUE"),
            pl.col("REGION_INTERVAL_MWH").sum().alias("REGION_MWH"),
            (pl.col("TOTAL_REVENUE").sum() / pl.col("ENERGY_MWH").sum()).alias(
                "CAPTURE_PRICE"
            ),
            pl.len().alias("N"),
        )
        .rename({"INTERVAL_END": "INTERVAL_START"})
        .filter(
            pl.col("N") > 24 * INTERVALS_PER_H
        )  # remove end periods with only a few samples
        .with_columns(
            (pl.col("REGION_REVENUE") / pl.col("REGION_MWH")).alias("REGION_GWAP")
        )
        .with_columns(
            (pl.col("CAPTURE_PRICE") / pl.col("TWAP")).alias("CAPTURE_RATE"),
            (pl.col("CAPTURE_PRICE") / pl.col("REGION_GWAP")).alias("PARTICIPATION_FACTOR"),
        )
        .select("INTERVAL_START", "FUEL_TYPE", metric)
        .sort("INTERVAL_START", "FUEL_TYPE")  # "REGIONID",
        .collect()
    )


In [100]:
df_pf = calculate("PARTICIPATION_FACTOR")
df_pf.write_csv(os.path.join(results_dir, f"results-participation-factor.csv"))

In [101]:
df_cr = calculate("CAPTURE_RATE")
df_cr.write_csv(os.path.join(results_dir, f"results-capture-rate.csv"))

In [109]:
df_cp = calculate("CAPTURE_PRICE")
df_cp.write_csv(os.path.join(results_dir, f"results-cp.csv"))


INTERVAL_START,FUEL_TYPE,CAPTURE_PRICE
datetime[ms],enum,f64
2024-01-01 00:00:00,"""battery""",-11154.421338
2022-01-01 00:00:00,"""distillate""",-3174.528453
2022-01-01 00:00:00,"""pumps""",0.886779
2023-01-01 00:00:00,"""pumps""",15.174493
2024-01-01 00:00:00,"""pumps""",21.432418
…,…,…
2016-01-01 00:00:00,"""distillate""",424.858951
2017-01-01 00:00:00,"""distillate""",849.233342
2023-01-01 00:00:00,"""distillate""",1055.178924
2024-01-01 00:00:00,"""distillate""",2050.924412


In [110]:
# "In 2024 gridscale solar power was solar for an average of x \$/MWh  #todo[get number], compared to y \$/MWh for gas."
(
    df_cp
    .filter(pl.col("INTERVAL_START").dt.year() == 2024)
    .sort("CAPTURE_PRICE")
)

INTERVAL_START,FUEL_TYPE,CAPTURE_PRICE
datetime[ms],enum,f64
2024-01-01 00:00:00,"""battery""",-11154.421338
2024-01-01 00:00:00,"""pumps""",21.432418
2024-01-01 00:00:00,"""solar_rooftop""",30.420047
2024-01-01 00:00:00,"""solar_gridscale""",53.290105
2024-01-01 00:00:00,"""wind""",77.392015
2024-01-01 00:00:00,"""coal""",125.541581
2024-01-01 00:00:00,"""hydro""",191.331475
2024-01-01 00:00:00,"""gas""",265.969536
2024-01-01 00:00:00,"""distillate""",2050.924412


In [102]:
color_mapping = {
    "coal": "brown",
    "gas": "black",
    "solar_gridscale": "orange",
    "solar_rooftop": "yellow",
    "wind": "green",
    #"hydro": "blue",
}

metrics = [
    {
        'df': df_pf,
        'title': "Participation Factor",
        'col': "PARTICIPATION_FACTOR",
    },
    {
        'df': df_cr,
        'title': "Capture Rate",
        'col': "CAPTURE_RATE",
    },
]

for metric in metrics:

    fig = px.line(
        metric['df'].filter(pl.col("FUEL_TYPE").is_in(color_mapping)),
        x="INTERVAL_START",
        y=metric['col'],
        color="FUEL_TYPE",
        color_discrete_map=color_mapping,
        height=400,
        width=800,
        title=f"{metric['title']} over time by fuel type (whole NEM)",
        labels={
            "FUEL_TYPE": "Fuel Type",
            "INTERVAL_START": "Time",
            metric['col']: metric['title'],
        }
    )
    
    fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
    slug = metric['title'].replace(' ', '_')
    fig.write_image(os.path.join(results_dir, f"all-regions-{slug}.svg"))
    
    fig.show()

In [103]:
# same again, but just solar in VIC and SA

df_solar = (
    pl.concat(
        [
            dispatchload
            .rename({"SETTLEMENTDATE": "INTERVAL_END"})
            .filter(pl.col("FUEL_TYPE") == "solar_gridscale"),
            (rooftop.with_columns(pl.lit("solar_rooftop").alias("FUEL_TYPE"))),
        ],
        how="diagonal_relaxed",
    )
    .filter((pl.col("REGIONID") == "SA1") | (pl.col("REGIONID") == "VIC1"))
    .with_columns((pl.col("POWER") / INTERVALS_PER_H).alias("ENERGY_MWH"))
    .cast(
        {
            "REGIONID": regionid_enum,
            "FUEL_TYPE": fuel_type_enum,
        }
    )
    .join(
        tradingprice.with_columns(
            pl.col("SETTLEMENTDATE").alias("INTERVAL_END"),
            pl.col("RRP").alias("ENERGY_PRICE"),
        ).cast(
            {
                "REGIONID": regionid_enum,
            }
        ),
        on=["REGIONID", "INTERVAL_END"],
    )
    .with_columns(
        [
            (pl.col("POWER") * pl.col("RRP") / INTERVALS_PER_H).alias("ENERGY_REVENUE"),
            (
                pl.sum_horizontal([pl.col(m) * pl.col(m + "RRP") for m in fcas_markets])
                / INTERVALS_PER_H
            ).alias("FCAS_REVENUE"),
        ]
    )
    .with_columns(
        (pl.col("ENERGY_REVENUE") + pl.col("FCAS_REVENUE")).alias("TOTAL_REVENUE")
    )
    .with_columns(
        pl.col("ENERGY_MWH")
        .sum()
        .over("REGIONID", "INTERVAL_END")
        .alias("REGION_INTERVAL_MWH"),
        pl.col("TOTAL_REVENUE")
        .sum()
        .over("REGIONID", "INTERVAL_END")
        .alias("REGION_INTERVAL_REVENUE"),
    )
    .sort("INTERVAL_END")
    .group_by_dynamic(
        index_column="INTERVAL_END",
        every="3mo",
        label="left",
        group_by=["FUEL_TYPE", "REGIONID"],  # "REGIONID",
    )
    .agg(
        pl.col("ENERGY_PRICE").mean().alias("TWAP"),
        pl.col("REGION_INTERVAL_REVENUE").sum().alias("REGION_REVENUE"),
        pl.col("REGION_INTERVAL_MWH").sum().alias("REGION_MWH"),
        (pl.col("TOTAL_REVENUE").sum() / pl.col("ENERGY_MWH").sum()).alias(
            "CAPTURE_PRICE"
        ),
        pl.len().alias("N"),
    )
    .rename({"INTERVAL_END": "INTERVAL_START"})
    .filter(
        pl.col("N") > 24 * INTERVALS_PER_H
    )  # remove end periods with only a few samples
    .with_columns(
        (pl.col("REGION_REVENUE") / pl.col("REGION_MWH")).alias("REGION_GWAP")
    )
    .with_columns(
        (pl.col("CAPTURE_PRICE") / pl.col("TWAP")).alias("CAPTURE_RATE"),
        (pl.col("CAPTURE_PRICE") / pl.col("REGION_GWAP")).alias("PARTICIPATION_FACTOR"),
    )
    .select("INTERVAL_START", "FUEL_TYPE", "CAPTURE_RATE", "REGIONID")
    .sort("INTERVAL_START", "FUEL_TYPE")  # "REGIONID",
    .collect()
)


In [104]:

fig = px.line(
    df_solar,
    x="INTERVAL_START",
    y="CAPTURE_RATE",
    color="FUEL_TYPE",
    #color_discrete_map=color_mapping,
    facet_col="REGIONID",
    height=400,
    width=800,
    title=f"{metric['title']} over time for solar in VIC and SA",
    labels={
        "FUEL_TYPE": "Fuel Type",
        "INTERVAL_START": "Time",
        "CAPTURE_RATE": "Capture Rate",
        "REGIONID": "Region"
    }
)

fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
fig.write_image(os.path.join(results_dir, f"vic-sa-solar.svg"))

fig.show()