# State-by-state two-party-preferred (2pp) voting trends
- Gaussian Random Walk model
- but no house-effects model
- and not constrained back to any national model

__NOTE:__
* Please run the *_data_capture.ipynb* notebook before running this notebook.

## Set-up

In [None]:
# system imports
from pathlib import Path
from typing import Any, NotRequired, TypedDict
from functools import cache

# analytic imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from IPython.display import display

# PyMC imports
import arviz as az  # type: ignore[import-untyped]
import pymc as pm  # type: ignore[import-untyped]

In [None]:
# local import
import bayes_tools
import plotting
from common import (
    MIDDLE_DATE,
    VOTING_INTENTION,
    ensure,
)
from data_capture import retrieve

In [None]:
# plotting related
# NOTE(review): SHOW is not referenced by the plotting calls in this notebook,
# which pass show=True explicitly -- confirm which behaviour is intended.
SHOW = False  # show charts in the notebook

# model diagram: output directory for the model graph images (created if absent)
MODEL_DIR = "../model-images/"
_model_dir = Path(MODEL_DIR)
_model_dir.mkdir(exist_ok=True, parents=True)

### Check operating environment

In [None]:
# record the software/hardware environment of this run (watermark extension)
%load_ext watermark
%watermark --python --machine --conda --iversions --watermark

## Get data

In [None]:
# state/territory codes, in the order used when extracting (and so plotting) the data
ORDERED_STATES = "NSW VIC QLD WA SA TAS NT ACT".split()

### State poll results

In [None]:
def get_poll_data() -> dict[str, pd.DataFrame]:
    """Load the polling tables saved by the data capture notebook.

    The result is passed through ``ensure`` (presumably raising when nothing
    has been captured) so the user is told to run the capture notebook first.
    """

    polls_by_table = retrieve()
    ensure(polls_by_table, "You must run the data capture notebook every day.")
    return polls_by_table

In [None]:
def extract_state_polls(
    raw: dict[str, pd.DataFrame] | None = None,
    minimum: int = 10,  # minimum required number of polls
    column: str = "2pp vote ALP",
) -> dict[str, pd.Series]:
    """Retrieve the 2pp state polls.

    Args:
        raw: poll tables keyed by VOTING_INTENTION (national) or state code;
            when None the tables are loaded with get_poll_data().
        minimum: jurisdictions with fewer usable polls than this are skipped.
        column: the poll column to extract.

    Returns:
        A dict mapping "Australia" (the national table) and each state code
        with enough usable polls to a Series of poll results. Each Series is
        indexed by the monthly Period of the poll's mid-date and sorted in
        time order. Polls missing either the mid-date or the value are dropped.
    """

    if raw is None:
        raw = get_poll_data()

    cooked_data = {}
    for state in [VOTING_INTENTION] + ORDERED_STATES:
        if state not in raw:
            print(f"Missing {state} data")
            continue

        # Keep only polls with both a mid-date and a reported value: a missing
        # date cannot be placed on the timeline, and a missing value would
        # otherwise reach the model as a NaN observation and inflate the count
        # checked against `minimum`. dropna() returns a new frame, so the raw
        # data is never modified.
        data = raw[state][[MIDDLE_DATE, column]].dropna()
        data.index = pd.PeriodIndex(data[MIDDLE_DATE], freq="M")
        data = data.sort_index(ascending=True)[column]

        if len(data) < minimum:
            print(f"Not enough data for {state}: n={len(data)}")
            continue

        cooked_data[state if state in ORDERED_STATES else "Australia"] = data

    return cooked_data


POLLS = extract_state_polls()
print(POLLS.keys())

### Previous (2022) election - state and national starting points

In [None]:
# ALP two-party-preferred vote (%) at the previous (2022) election, nationally
# and by state/territory. Each value is the starting point of that
# jurisdiction's random walk and the reference line on its chart.
previous_election = dict(
    Australia=52.13,
    NSW=51.42,
    VIC=54.83,
    QLD=45.95,
    WA=55,
    SA=53.97,
    TAS=54.33,
    ACT=66.95,
    NT=55.54,
)

## Data preparation

### National Data

In [None]:
def get_timeline() -> tuple[pd.Period, pd.Period]:
    """Get the start and end of the timeline for the national data.

    The first month is one month before the earliest national poll. This gives
    the random walk an initial month (anchored at the previous election result)
    before any poll is observed. The last month is the month of the latest
    national poll.

    Data is loaded through get_poll_data(), so the same "run the data capture
    notebook" check applies here as everywhere else in this notebook.
    """

    national = get_poll_data()
    national_dates = pd.PeriodIndex(
        national[VOTING_INTENTION][MIDDLE_DATE].dropna(), freq="M"
    )

    # - time frames
    first_month = national_dates.min() - 1  # 1 month before the first poll
    last_month = national_dates.max()
    return first_month, last_month


FIRST_MONTH, LAST_MONTH = get_timeline()
N_MONTHS = (LAST_MONTH - FIRST_MONTH).n + 1  # inclusive count of months
assert N_MONTHS > 1

### State data

In [None]:
def build_states_data() -> dict[str, dict[str, Any]]:
    """Assemble the per-jurisdiction model inputs from POLLS.

    Each jurisdiction maps to a dict with:
        y          -- the poll results as a numpy array
        n_polls    -- the number of polls
        start      -- the previous election result (random-walk anchor)
        poll_month -- each poll's month offset from FIRST_MONTH
    """

    def _inputs_for(name: str) -> dict[str, Any]:
        polls = POLLS[name]
        # every poll must fall inside the modelled timeline
        assert polls.index.min() >= FIRST_MONTH
        assert polls.index.max() <= LAST_MONTH
        return {
            "y": polls.to_numpy(),
            "n_polls": len(polls),
            "start": previous_election[name],
            "poll_month": [(month - FIRST_MONTH).n for month in polls.index],
        }

    return {name: _inputs_for(name) for name in POLLS}


STATES_DATA = build_states_data()
print(STATES_DATA.keys())

## Build a model to estimate state-based voting intention

### Sampler settings

In [None]:
class SampleSettings(TypedDict):
    """The settings for the Bayesian model.

    Built by sampler_settings() and unpacked as keyword arguments into
    bayes_tools.draw_samples(). Keys marked NotRequired may be omitted.
    """

    draws: int  # posterior draws (presumably per chain -- confirm in draw_samples)
    tune: int  # tuning steps (presumably per chain)
    cores: int  # number of cores to sample with
    chains: int  # number of independent chains
    nuts_sampler: str  # NUTS backend name, e.g. "numpyro"
    nuts: NotRequired[dict[str, Any]]  # optional extra NUTS settings
    plot_trace: NotRequired[bool]  # optional flag; its handling is inside draw_samples
    plot_trace: NotRequired[bool]


def sampler_settings() -> SampleSettings:
    """Build the keyword arguments handed to ``bayes_tools.draw_samples``.

    One chain per core (5 of each), 2,000 tuning steps and 2,000 draws,
    using the NumPyro NUTS backend, with trace plotting switched off.
    """

    n_chains = 5  # one chain per core
    return SampleSettings(
        draws=2_000,  # draws per chain (== per core here, as cores == chains)
        tune=2_000,  # tuning steps per chain
        cores=n_chains,
        chains=n_chains,
        nuts_sampler="numpyro",
        plot_trace=False,
    )

### Model build

In [None]:
def build_model(
    for_state: str,
    walk_sigma: float = 0.75,
    poll_sigma: float = 1.0,
) -> pm.Model:
    """Builds a simple PyMC model for a monthly Gaussian Random Walk.
    NOTE: model is working with whole percentage points, not fractions.

    Args:
        for_state: key into STATES_DATA ("Australia" or a state code).
        walk_sigma: standard deviation of the month-to-month random-walk
            steps, in percentage points.
        poll_sigma: standard deviation of each poll around the latent
            monthly voting intention, in percentage points.

    Returns:
        An unsampled pm.Model with a latent "vi" walk over the "months" dim
        (anchored at the previous election result) and the observed polls "y"
        over the "polls" dim.
    """

    # -- state specific data
    assert for_state in STATES_DATA
    params = STATES_DATA[for_state]
    start = params["start"]
    print(for_state)

    # -- model
    coords = {
        "months": range(N_MONTHS),
        "polls": range(params["n_polls"]),
    }

    with (model := pm.Model(coords=coords)):
        # -- temporal model: latent monthly voting intention
        vi = pm.GaussianRandomWalk(
            "vi",
            mu=0,
            # anchor firmly to the previous election result
            init_dist=pm.Normal.dist(mu=start, sigma=0.0001),
            sigma=walk_sigma,
            dims="months",
        )

        # -- likelihood / observational model: each poll observes the
        # latent intention in the month (offset) it was taken
        pm.Normal(
            "y",
            mu=vi[params["poll_month"]],
            sigma=poll_sigma,
            observed=params["y"],
            dims="polls",
        )

    return model


# quick check
_test = build_model("Australia")

### Run the model for each State

In [None]:
def run_the_model(make_map: bool = False) -> dict[str, tuple]:
    """Sample the model for every jurisdiction in STATES_DATA.

    Args:
        make_map: when True, also call bayes_tools.generate_model_map for each
            model (output into MODEL_DIR) before sampling.

    Returns:
        A dict mapping each jurisdiction to the (idata, glitches) pair returned
        by bayes_tools.draw_samples.
    """

    # the sampler configuration is identical for every jurisdiction
    sampling = sampler_settings()

    results: dict[str, tuple] = {}
    for state in STATES_DATA:
        state_model = build_model(state)

        if make_map:
            # NOTE(review): the "dirichlet_" prefix looks like a leftover from
            # another notebook -- this model is a Gaussian random walk.
            bayes_tools.generate_model_map(
                state_model, f"dirichlet_state_vi_{state}", MODEL_DIR, display_images=True
            )

        with state_model:
            idata, glitches = bayes_tools.draw_samples(state_model, **sampling)
        results[state] = (idata, glitches)

    return results


model_results = run_the_model()

## Charts

In [None]:
@cache
def _get_var(state: str, var_name: str) -> pd.Series:
    """Extract the chains/draws for a specified var_name.

    The chains and draws are combined into a single "sample" dimension by
    az.extract. Results are cached per (state, var_name), so repeated calls
    return the same Series object; callers should not modify it in place.
    """

    idata, _glitches = model_results[state]
    posterior = az.extract(idata)
    return posterior.transpose("sample", ...).to_dataframe()[var_name]

In [None]:
# useful reminder of the index names (one line per modelled jurisdiction)
for state_ in model_results:
    print(state_, _get_var(state_, "vi").index.names)

### Raw plots

In [None]:
def plot_state_timeseries(
    state: str,
    vi_data: pd.DataFrame,
    glitches: str,
) -> float:
    """Plot the monthly posterior ALP 2pp voting intention for one jurisdiction.

    Args:
        state: key into STATES_DATA and previous_election.
        vi_data: single-column frame of "vi" samples whose row index has a
            "months" level (as built by state_plots()).
        glitches: sampler diagnostics text, shown in the chart header if non-empty.

    Returns:
        The change, in percentage points, from the previous election result to
        the median voting intention in the final month.
    """

    # set-up: each tail percentage p gives a central (100 - 2p)% band;
    # narrower bands are drawn darker and on top.
    percents = [2.5, 25, 47.5]
    intensities = [
        (p - min(percents)) / (max(percents) - min(percents)) for p in percents
    ]
    min_intensity = 0.25
    intensity_fracs = [c * (1.0 - min_intensity) + min_intensity for c in intensities]
    start_month: pd.Period = pd.Period(FIRST_MONTH, freq="M")

    # reshape to one column per month offset, one row per posterior sample
    month_data = vi_data.unstack(level="months")
    month_data.columns = month_data.columns.droplevel().astype(int)
    medians = month_data.median()
    # "YY-MM" string labels, so the x-axis is categorical
    x = [str(month + start_month)[2:] for month in month_data.columns]

    # plot
    _fig, ax = plt.subplots()
    color = plt.get_cmap("Reds")
    for i, pct in enumerate(percents):
        quants = pct, 100 - pct
        # NOTE: these are equal-tailed quantile intervals; the "HDI" wording
        # is kept so the legend matches the existing charts.
        label = f"{quants[1] - quants[0]:0.0f}% HDI"
        lower, upper = [month_data.quantile(q=q / 100.0) for q in quants]
        ax.fill_between(
            x=x,  # type: ignore[arg-type]
            y1=lower.to_list(),
            y2=upper.to_list(),
            color=color(intensity_fracs[i]),
            alpha=0.5,
            label=label,
            zorder=i + 1,
        )

    previous = previous_election[state]
    ax.axhline(
        y=previous,
        color="#333333",
        linestyle="--",
        linewidth=0.75,
        label="Previous election",
    )
    # annotate the final median; month offsets run 0..N-1, so the last offset
    # is also the position of the last category on the x-axis
    ax.text(
        x=medians.index[-1],
        y=medians.iloc[-1],
        s=f" {medians.iloc[-1]:0.1f}%",
        color="#333333",
        fontsize="xx-small",
        ha="left",
        va="center",
    )
    ax.tick_params(axis="x", rotation=90, labelsize="x-small")
    poll_count = STATES_DATA[state]["n_polls"]
    plotting.finalise_plot(
        ax,
        title=f"Bayesian Aggregation: {state} ALP 2pp Voting Intention (Monthly)",
        ylabel="Per cent two-party-preferred vote",
        xlabel="Year-Month",
        legend={"loc": "upper right", "fontsize": "xx-small", "ncol": 4},
        lfooter="Data sourced from Wikipedia. House effects ignored. GRW. "
        + f"Based on {poll_count} poll{'s' if poll_count > 1 else ''}. "
        + "2pp=Pollster Estimates. ",
        rfooter="marktheballot.blogspot.com",
        rheader=glitches if glitches else None,
        y50=True,
        show=True,
    )

    return medians.iloc[-1] - previous

In [None]:
def plot_changes(movements: dict[str, float]) -> None:
    """Plot the changes in the voting intention data
    since the 2022 Election.

    Draws one bar per jurisdiction. Each bar's value is written just across
    the zero line from the bar: above zero for falls, below zero for rises.

    Args:
        movements: change in ALP 2pp (percentage points) keyed by jurisdiction.
    """

    series = pd.Series(movements)
    _fig, ax = plt.subplots()
    ax.bar(
        x=series.index,
        height=series,
        color="red",
        alpha=0.67,
    )

    # Small vertical offset of the value labels from zero. When every change
    # is identical the range is 0, so fall back to the data's magnitude (or 1
    # point) to keep the labels off the zero line.
    span = series.max() - series.min()
    if not span > 0:
        span = series.abs().max() or 1.0
    inc = span * 0.01

    for state, c in movements.items():
        ax.text(
            x=state,  # type: ignore[arg-type]
            y=inc if c < 0 else -inc,
            s=f"{c:.1f}",
            va="bottom" if c < 0 else "top",
            ha="center",
            color="#444444",
            fontsize="small",
        )

    plotting.finalise_plot(
        ax,
        title="Change in ALP 2pp Voting Intention since the 2022 Election",
        xlabel=None,
        ylabel="Percentage points",
        lfooter="Data sourced from Wikipedia. House effects ignored. Monthly GRW. "
        + "2pp=Pollster Estimates. ",
        rfooter="marktheballot.blogspot.com",
        y0=True,
        show=True,
    )

In [None]:
def state_plots() -> None:
    """Chart each modelled jurisdiction's voting intention, then summarise the
    change since the previous election in a single bar chart."""

    movements = {}
    modelled = (name for name in STATES_DATA if name in model_results)
    for state in modelled:
        _idata, glitches = model_results[state]
        samples: pd.DataFrame = pd.DataFrame(_get_var(state, "vi"))
        movements[state] = plot_state_timeseries(state, samples, glitches)
    plot_changes(movements)


state_plots()

## Finished

In [None]:
# signal that the notebook ran through to the end
print("Finished")