# Dynamic distributions

A dynamic distribution changes the total amount of the edge depending on the time at which the edge occurs.

**Note**: This behaviour works, but could produce unexpected results, as our initial graph traversal results become unreliable. It is strongly recommended that the data only get smaller (i.e. sum to less than one) - otherwise, the graph traversal will have cut off some graph paths which could have been significant in the future or past.

To create a dynamic distribution, we need to do the following:

* Create a new temporal distribution class which subclasses `bw_temporalis.TDAware`. The class should do the following:
    * Define `__init__`
    * Define `to_json` and `from_json`
    * Define `__mul__`
    * Define `__add__` if applicable, or prevent it
    * Set `_mul_comes_first = True`
* Register this class in `bw_temporalis.loader_registry`

In [None]:
import bw_temporalis as bwt
import bw2data as bd
import bw2calc as bc

In [None]:
from collections.abc import Mapping
from numbers import Number
from typing import Any, Union
import json
import numpy as np


class LinearDecreaseOverTime(bwt.TDAware):
    """Dynamic distribution which scales amounts by a time-dependent factor.

    The factor equals ``start_value`` at or before ``start_dt``, equals
    ``end_value`` at or after ``end_dt``, and is linearly interpolated between
    the two dates. Multiplying an instance by an absolute (datetime-based)
    ``TemporalDistribution`` scales each amount by the factor valid at that
    amount's date.
    """

    # Make sure that we control multiplication
    _mul_comes_first = True

    def __init__(self, start_dt: str, start_value: Number, end_dt: str, end_value: Number, **kwargs: Any):
        """Store the two anchor points of the linear ramp.

        Args:
            start_dt: Start of the ramp; anything NumPy can parse as
                ``datetime64[s]`` (e.g. ``"2024"``).
            start_value: Factor applied at or before ``start_dt``.
            end_dt: End of the ramp, same format as ``start_dt``.
            end_value: Factor applied at or after ``end_dt``.
            **kwargs: Ignored. Accepted so that ``from_json`` can pass the
                complete serialized mapping, including ``"__loader__"``.

        Raises:
            ValueError: If ``end_dt`` is not strictly later than ``start_dt``.
        """
        # The original strings are kept for serialization and for building
        # rescaled copies in `__mul__`.
        self._start_str = start_dt
        self._end_str = end_dt

        # Integer seconds since the epoch, comparable with
        # `TemporalDistribution.date.astype(int)` in `__mul__`.
        self.start = np.array(start_dt, dtype="datetime64[s]").astype(int)
        self.end = np.array(end_dt, dtype="datetime64[s]").astype(int)
        if not self.end > self.start:
            raise ValueError("`start` must come before `end`")

        self.a, self.b = float(start_value), float(end_value)

    def __mul__(
        self, other: Union[bwt.TemporalDistribution, Number]
    ) -> Union[bwt.TemporalDistribution, "LinearDecreaseOverTime"]:
        """Multiply by a number or by an absolute temporal distribution.

        Args:
            other: A plain number (scales both anchor values and returns a new
                ``LinearDecreaseOverTime``) or a datetime-based
                ``TemporalDistribution`` (returns a new ``TemporalDistribution``
                with the same dates and scaled amounts).

        Raises:
            ValueError: If ``other`` is another ``TDAware`` object, a relative
                ``TemporalDistribution``, or any unsupported type.
        """
        if isinstance(other, bwt.TDAware):
            raise ValueError("Can't multiply two dynamic functions")
        elif isinstance(other, Number):
            return LinearDecreaseOverTime(
                start_dt=self._start_str,
                start_value=self.a * other,
                end_dt=self._end_str,
                end_value=self.b * other,
            )
        elif isinstance(other, bwt.TemporalDistribution):
            # The factor depends on calendar time, so it is undefined for
            # distributions expressed as offsets.
            if other.base_time_type != bwt.temporal_distribution.datetime_type:
                raise ValueError("Can't multiply by relative distribution")
            new_data = np.array([
                self.value_at_time(value=what, dt=when)
                for what, when in zip(other.amount, other.date.astype(int))
            ])
            return bwt.TemporalDistribution(
                date=other.date,
                amount=new_data,
            )
        else:
            raise ValueError(
                f"Can't multiply `LinearDecreaseOverTime` and {type(other)}"
            )

    def value_at_time(self, value: Number, dt: Union[int, np.integer]) -> float:
        """Return ``value`` scaled by the factor valid at time ``dt``.

        Args:
            value: Amount to scale.
            dt: Point in time as integer seconds since the epoch (the same
                representation as ``self.start`` and ``self.end``).

        Returns:
            ``value`` multiplied by ``start_value`` before the ramp,
            ``end_value`` after it, and the linear interpolation in between.
        """
        if dt <= self.start:
            factor = self.a
        elif dt >= self.end:
            factor = self.b
        else:
            # Share of the ramp already elapsed: 0 at `start`, 1 at `end`.
            elapsed = (dt - self.start) / (self.end - self.start)
            # Signed interpolation, so the ramp is continuous at both ends
            # whether the factor decreases or increases.
            factor = self.a + (self.b - self.a) * elapsed
        # The incoming amount must be carried through, not replaced by the factor.
        return float(value * factor)

    def to_json(self) -> str:
        """Serialize to a JSON string which `from_json` can read back."""
        return json.dumps(
            {
                "__loader__": "LinearDecreaseOverTime",
                "start_dt": self._start_str,
                "end_dt": self._end_str,
                "start_value": self.a,
                "end_value": self.b,
            }
        )

    @classmethod
    def from_json(cls, json_obj: Union[Mapping, str]) -> "LinearDecreaseOverTime":
        """Build an instance from the output of `to_json`.

        Args:
            json_obj: Either the JSON string itself or the already-decoded
                mapping.

        Raises:
            ValueError: If ``json_obj`` is neither a mapping nor a string.
        """
        if isinstance(json_obj, Mapping):
            data = json_obj
        elif isinstance(json_obj, str):
            data = json.loads(json_obj)
        else:
            raise ValueError(f"Can't understand `from_json` input object {json_obj}")
        # Extra keys such as "__loader__" are absorbed by `**kwargs` in `__init__`.
        return cls(**data)

        
# Register the deserializer under the same key that `to_json` writes into
# "__loader__" (presumably how bw_temporalis finds the loader - confirm).
bwt.loader_registry["LinearDecreaseOverTime"] = LinearDecreaseOverTime.from_json

## A basic database

In [None]:
# Switch to the Brightway project used for this example.
bd.projects.set_current("dynamic distribution example")

In [None]:
# Write a minimal database: one CO2 flow and two activities, `a` and `b`.
bd.Database('example').write({
    # Emission flow; `b` emits it and the "GWP" method below characterizes it.
    ('example', "CO2"): {
        "type": "emission",
        "name": "carbon dioxide",
        "temporalis code": "co2",
    },
    # `a` consumes 50 units of `b`, spread over a relative distribution
    # (presumably 11 yearly steps from 0 to 10 years - confirm against
    # `easy_timedelta_distribution`).
    ('example', 'a'): {
        'name': 'First one',
        'exchanges': [
            {
                'amount': 50,
                'input': ('example', 'b'),
                'temporal_distribution': bwt.easy_timedelta_distribution(
                    start=0,
                    end=10,
                    resolution="Y",
                    steps=11,
                ),
                'type': 'technosphere'
            },
        ],
    },
    # `b` emits CO2 through the dynamic distribution defined above: the factor
    # goes from 1 at the start of 2024 to 0.2 at the start of 2028.
    ('example', 'b'): {
        'name': 'Second one',
        'exchanges': [
            {
                'amount': 1,
                'input': ('example', 'CO2'),
                'type': 'biosphere',
                'temporal_distribution': LinearDecreaseOverTime(
                    start_dt="2024",
                    end_dt="2028",
                    start_value=1, 
                    end_value=0.2
                )   
            },
        ],
        'type': 'process'
    },
})

In [None]:
# Impact method with a single characterization factor: 1 per unit of CO2.
bd.Method(("GWP", "example")).write([
    (('example', "CO2"), 1),
])

In [None]:
# Static LCA for one unit of `a` with the "GWP" method; the temporal
# calculation below is built from this object.
lca = bc.LCA({('example', 'a'): 1}, ("GWP", "example"))
lca.lci()
lca.lcia()
# Bare expression: shown as the notebook cell output.
lca.score

In [None]:
# Temporal LCA on top of the static one. NOTE(review): `cutoff=0.1` presumably
# controls the graph traversal cutoff mentioned in the introduction - confirm.
tlca = bwt.TemporalisLCA(lca, cutoff=0.1)
tl = tlca.build_timeline()
# Needed before `tl.df` is read by the plotting cell.
tl.build_dataframe()

In [None]:
# Scatter plot of the timeline dataframe: amount against date.
tl.df.plot(x="date", y="amount", kind="scatter")