In [None]:
# ruff: noqa: F401
from __future__ import annotations

import os
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING

import django
from django.conf import settings  # pyright: ignore[reportUnusedImport]

# import altair as alt
import polars as pl
import yaml

# from great_tables import GT

# Allow Django to run in async environments (like Jupyter)
os.environ["DJANGO_ALLOW_ASYNC_UNSAFE"] = "true"

# Set the Django settings module.
# setdefault: a DJANGO_SETTINGS_MODULE already exported in the environment wins.
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'paths.settings')

# Configure Django. This has to happen after the env vars above are set and
# before the project imports below (hence their `noqa: E402`), presumably
# because those modules need a configured Django.
django.setup()

from common import polars as ppl  # noqa: E402
from nodes.constants import FORECAST_COLUMN, VALUE_COLUMN, YEAR_COLUMN  # noqa: E402
from nodes.exceptions import NodeComputationError  # noqa: E402
from nodes.units import Quantity, unit_registry  # noqa: E402
from notebooks.notebook_support import get_context, get_nodes, initialize_notebook_env  # noqa: E402

# Notebook-specific setup provided by notebooks.notebook_support; see that
# module for exactly what it configures.
initialize_notebook_env()

if TYPE_CHECKING:
    from common.polars import PathsDataFrame


There was a previous version of this functionality, implemented as a Django management command. It had the advantage of accessing the production instances directly, but otherwise its code was very cumbersome.
https://github.com/kausaltech/kausal-paths/blob/4fd525e23bb96cefdd991a10b3354daf9fe3be3c/nodes/management/commands/collect_city_data.py

This version might work better as a terminal command, but I'll keep it as a Jupyter notebook for now.

In [None]:
# YAML config read below; it supplies `instances`, `nodes`, `processors` and
# `output_path`. The path is relative, so it resolves against the current
# working directory (typically the directory the notebook was started from).
config_file = '../netzeroplanner-framework-config/emission_potential.yaml'

In [None]:
@dataclass
class NodeData:
    """Individual node with its dataframe.

    Pairs a node id with the dataframe collected for it. Processing steps
    replace ``df`` on the same object as they transform the data.
    """

    id: str  # node id, as listed under `nodes` in the config
    df: ppl.PathsDataFrame  # current dataframe for this node


@dataclass
class InstanceData:
    """Instance containing multiple nodes.

    Holds the node dataframes collected for one instance together with the
    instance's target year. Nodes are kept in insertion order and looked up
    by id with a linear scan; if several nodes share an id, the first wins.
    """

    id: str
    target_year: int
    nodes: list[NodeData] = field(default_factory=list)

    def _find_node(self, node_id: str) -> NodeData | None:
        """Return the first node whose id equals *node_id*, or None."""
        return next((node for node in self.nodes if node.id == node_id), None)

    def add_node(self, node_id: str, df: ppl.PathsDataFrame) -> NodeData:
        """Append a new node wrapping *df* and return it."""
        node = NodeData(id=node_id, df=df)
        self.nodes.append(node)
        return node

    def get_node_df(self, node_id: str) -> ppl.PathsDataFrame | None:
        """Return the dataframe of node *node_id*, or None if it is absent."""
        node = self._find_node(node_id)
        return None if node is None else node.df

    def update_node_df(self, node_id: str, df: ppl.PathsDataFrame) -> InstanceData:
        """Replace the dataframe of node *node_id* and return ``self``.

        Raises:
            KeyError: if this instance has no node with *node_id*.
        """
        node = self._find_node(node_id)
        if node is None:
            # Explicit raise instead of `assert`: asserts vanish under `python -O`.
            raise KeyError(f"Node {node_id!r} not found in instance {self.id!r}.")
        node.df = df
        return self


@dataclass
class DataCollection:
    """Top-level container for everything gathered in one collection run.

    Bundles the output settings, the names of the processors to run, the
    log messages accumulated along the way, the per-instance data, and any
    summaries derived from it.
    """

    output_path: str
    output_date: str
    processors: list[str] = field(default_factory=list)
    logs: list[str] = field(default_factory=list)
    instances: list[InstanceData] = field(default_factory=list)
    summaries: list[InstanceData] = field(default_factory=list)

    def add_instance(self, instance_id: str, target_year: int) -> InstanceData:
        """Create an empty InstanceData, register it, and return it."""
        new_instance = InstanceData(id=instance_id, target_year=target_year)
        self.instances.append(new_instance)
        return new_instance

    def get_instance(self, instance_id: str) -> InstanceData | None:
        """Return the first registered instance with *instance_id*, or None."""
        for candidate in self.instances:
            if candidate.id == instance_id:
                return candidate
        return None


In [None]:
def read_config(yaml_file):
    """Load and return the parsed contents of the YAML file *yaml_file*.

    Parsing uses ``yaml.safe_load``, so only plain YAML types are built.
    The file is opened as UTF-8 and closed again before returning.
    """
    with Path(yaml_file).open('r', encoding='utf-8') as fh:
        return yaml.safe_load(fh)

def find_target_values(dc: DataCollection) -> DataCollection:
    """Reduce every node dataframe to its newest observed and target-year rows.

    For each node, the newest observation is the latest year whose
    ``FORECAST_COLUMN`` is false. Only the rows for that year and for the
    instance's target year are kept. A ``param`` column is added to the
    index, holding 'newest' for the observed year and 'target' otherwise.
    If the target year equals the newest observed year, its row is labelled
    'newest'.

    Args:
        dc: Collection whose node dataframes are replaced in place.

    Returns:
        The same ``dc`` object.
    """
    for instance in dc.instances:
        target_year = instance.target_year
        for node in instance.nodes:
            df: ppl.PathsDataFrame = node.df
            meta = df.get_meta()
            obs_year = df.filter(~pl.col(FORECAST_COLUMN))[YEAR_COLUMN].max()
            df = (
                df.filter(pl.col(YEAR_COLUMN).is_in([obs_year, target_year]))
                .sort(by=[YEAR_COLUMN])
                .with_columns(
                    pl.when(pl.col(YEAR_COLUMN) == obs_year)
                    .then(pl.lit('newest'))
                    .otherwise(pl.lit('target'))
                    .alias('param')
                )
            )
            # Assign to this node directly. Re-finding it by id (as
            # update_node_df does) would hit the wrong node if ids repeat.
            node.df = ppl.to_ppdf(df, meta).add_to_index('param')
    return dc

def convert_to_target_units(dc: DataCollection) -> DataCollection:
    """Multiply node values by a unit multiplier chosen from the node's unit.

    Each ``multipliers`` entry maps a unit to the quantity that values are
    multiplied by when the node's ``VALUE_COLUMN`` unit is compatible with
    it. Currently ``kt_co2e/a`` values are multiplied by ``1 kt/kt_co2e``,
    which cancels the CO2e unit. At most one multiplier is applied per node.

    Args:
        dc: Collection whose node dataframes are replaced in place.

    Returns:
        The same ``dc`` object.
    """
    multipliers: dict[str, Quantity] = {
        'kt_co2e/a': unit_registry('1 * kt/kt_co2e'),
    }
    for instance in dc.instances:
        for node in instance.nodes:
            df: PathsDataFrame = node.df
            df_unit = df.get_meta().units[VALUE_COLUMN]
            for from_unit, multiplier in multipliers.items():
                if df_unit.is_compatible_with(from_unit):
                    # Stop at the first match. df_unit is not refreshed after
                    # multiplying, so checking further entries against the
                    # stale unit could apply a conversion twice.
                    node.df = df.multiply_quantity(VALUE_COLUMN, multiplier)
                    break
    return dc

def sum_over_dfs(dc: DataCollection) -> DataCollection:
    """Sum every node dataframe over all primary-key dimensions except year.

    After this step ``YEAR_COLUMN`` is the only dimension left unsummed.

    Args:
        dc: Collection whose node dataframes are replaced in place.

    Returns:
        The same ``dc`` object.
    """
    for instance in dc.instances:
        for node in instance.nodes:
            dims_to_drop = [dim for dim in node.df.primary_keys if dim != YEAR_COLUMN]
            # Write back to this exact node. A lookup by id could pick a
            # different node when ids are duplicated.
            node.df = node.df.paths.sum_over_dims(dims_to_drop)
    return dc

def sum_over_instances(dc: DataCollection) -> DataCollection:
    """Add up each node's dataframe across all instances.

    The result is one extra InstanceData (id ``'sum_over_instances'``,
    placeholder target year 0) appended to ``dc.summaries``. The first
    dataframe seen for a node id seeds its total. Later ones are added with
    ``paths.add_df`` only when their primary keys match as a set; otherwise
    a message is logged and that dataframe is left out.
    """
    totals: dict[str, PathsDataFrame] = {}
    for instance in dc.instances:
        for node in instance.nodes:
            running = totals.get(node.id)
            if running is None:
                totals[node.id] = node.df
                continue
            if set(running.primary_keys) != set(node.df.primary_keys):
                dc.logs.append(
                    f"Node {node.id} has primary keys {node.df.primary_keys} in instance {instance.id}"
                    f" but expected {running.primary_keys}."
                )
                continue
            totals[node.id] = running.paths.add_df(odf=node.df)

    summary = InstanceData(id='sum_over_instances', target_year=0)
    for node_id, total in totals.items():
        summary.add_node(node_id, total)
    dc.summaries.append(summary)
    return dc

def report_log(dc: DataCollection) -> None:
    """Print a header line followed by each message in ``dc.logs``, one per line."""
    print("\nDuring processing, the following things happened:")
    if dc.logs:
        print(*dc.logs, sep="\n")

def save_summaries(dc: DataCollection) -> DataCollection:
    """Write every summary node dataframe to its own CSV file.

    File names are ``f"{dc.output_path}{summary.id}_{node.id}.csv"``. The
    output path is a plain string prefix, so it must end with a path
    separator when it names a directory. Missing parent directories are
    created. Progress messages are appended to ``dc.logs``.

    Returns:
        The same ``dc`` object.
    """
    dc.logs.append("Saving summaries about:")
    output_path = dc.output_path
    for summary in dc.summaries:
        dc.logs.append(f"- {summary.id}.")
        for node in summary.nodes:
            output_file = f"{output_path}{summary.id}_{node.id}.csv"
            # Create the target directory first so a fresh output_path does
            # not make the CSV write fail.
            Path(output_file).parent.mkdir(parents=True, exist_ok=True)
            node.df.write_csv(output_file)
            dc.logs.append(f"  - Saved data in {output_file}.")
    return dc

def no_processing(dc: DataCollection) -> DataCollection:
    """Identity processor: hand ``dc`` back untouched."""
    return dc

# Processors selectable by name from the config's `processors` list. Each one
# takes the DataCollection and returns it (possibly modified).
postprocess_data = dict(
    convert_to_target_units=convert_to_target_units,
    find_target_values=find_target_values,
    save_summaries=save_summaries,
    sum_over_dfs=sum_over_dfs,
    sum_over_instances=sum_over_instances,
    none=no_processing,
)

In [None]:
config = read_config(config_file)
# `or []`: without a `processors` entry, config.get returns None and the
# processor loop at the bottom would fail with a TypeError.
processors = config.get('processors') or []
output_path = config.get('output_path')
output_date: str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")  # noqa: DTZ005

data = DataCollection(output_path, output_date, processors=processors)

instances = config['instances']
# instances = instances[0:10] # Used to simplify testing
node_ids = [node['id'] for node in config['nodes']]

# Collect the output of every configured node from every configured instance.
# Missing or failing nodes are logged and skipped rather than aborting the run.
for instance_id in instances:
    nodes = get_nodes(instance_id)
    target_year = get_context(instance_id).target_year
    instance = data.add_instance(instance_id=instance_id, target_year=target_year)
    for node_id in node_ids:
        node = nodes.get(node_id)
        if node is None:
            data.logs.append(f"Node {node_id} not found in instance {instance.id}.")
            continue
        try:
            df = node.get_output_pl()
        except (ValueError, NodeComputationError) as err:
            data.logs.append(
                f"Node {node_id} in instance {instance.id} gave an error and is skipped: {err}"
            )
            continue
        instance.add_node(node_id=node_id, df=df)

# Run the configured processors in order; unknown names are logged and skipped.
for processor in data.processors:
    process = postprocess_data.get(processor)
    if process is None:
        data.logs.append(f"Processor {processor} is not defined. Ignoring.")
        continue
    data = process(data)



In [None]:
# Print every message accumulated in data.logs during collection and processing.
report_log(data)