In [None]:
"""Try out plotting by target."""
import os
from itertools import product
from pathlib import Path

import numpy as np
import pandas as pd
import plotly.graph_objects as go

from pyvmte.config import BLD, SETUP_FIG5, SIMULATION_RESULTS_DIR
from pyvmte.estimation.estimation import _compute_u_partition, _generate_basis_funcs
from pyvmte.identification.identification import identification
from pyvmte.utilities import load_paper_dgp

# Directory holding the per-target simulation results that are loaded below.
BY_TARGET = SIMULATION_RESULTS_DIR.joinpath("by_target")
# DGP returned by `load_paper_dgp`; later cells read its "support_z", "pscore_z",
# "pdf_z", "m0" and "m1" entries.
DGP = load_paper_dgp()

In [None]:
# Collect the simulation result pickles in BY_TARGET.
# - Sorting makes the row order of the concatenated frame reproducible; the order
#   in which a directory is listed is arbitrary.
# - The ".pkl" filter keeps stray files (e.g. ".DS_Store") out of `read_pickle`;
#   the filename parsing in the next cell expects the ".pkl" suffix as well.
files = sorted(
    path.name
    for path in BY_TARGET.iterdir()
    if path.is_file() and path.suffix == ".pkl"
)
if not files:
    msg = f"No '.pkl' result files found in {BY_TARGET}."
    raise FileNotFoundError(msg)

# NOTE: unpickling executes arbitrary code; only load files produced by this project.
# The source filename is kept as a column so the target can be recovered from it.
dfs = [pd.read_pickle(BY_TARGET / f).assign(filename=f) for f in files]  # noqa: S301
df_estimates = pd.concat(dfs, ignore_index=True)
df_estimates.head()

In [None]:
# From the filename column extract the string between "u_hi_" and ".pkl" and
# convert it to a float. Filenames that do not match the pattern yield NaN.
df_estimates["u_hi"] = (
    df_estimates["filename"]
    .str.extract(r"u_hi_(.*)\.pkl", expand=False)
    .astype(float)
)

# Only keep those with u_hi on the grid np.arange(0.35, 1, 0.05).
# The grid is built by floating-point arithmetic (e.g. 0.35 + 0.05 evaluates to
# 0.39999999999999997, not 0.4), so an exact `isin` comparison silently drops
# valid targets. Compare with a tolerance instead; NaN never matches and is
# therefore dropped, exactly as with `isin`.
u_hi_grid = np.arange(0.35, 1, 0.05)
on_grid = np.isclose(
    df_estimates["u_hi"].to_numpy()[:, None],
    u_hi_grid[None, :],
).any(axis=1)

# `.copy()` makes the result an independent frame so that re-running this cell
# assigns the "u_hi" column without chained-assignment warnings.
df_estimates = df_estimates[on_grid].copy()
df_estimates["u_hi"].unique()

In [None]:
# Violin plots of the bound estimates: for every target value one violin for the
# upper bound (green) followed by one for the lower bound (blue).
fig = go.Figure()

# Distinct u_hi values present in the estimates (also used by later cells).
unique_target_values = df_estimates["u_hi"].unique()

for target_value in unique_target_values:
    # Rows belonging to the current target; shared by both traces.
    rows = df_estimates["u_hi"] == target_value

    for column, color in (("upper_bound", "green"), ("lower_bound", "blue")):
        fig.add_trace(
            go.Violin(
                x=df_estimates.loc[rows, "u_hi"],
                y=df_estimates.loc[rows, column],
                name=target_value,
                box_visible=True,
                meanline_visible=True,
                line_color=color,
            ),
        )

# Hide the legend and set the figure title in one layout update.
fig.update_layout(
    showlegend=False,
    title_text="Sharp Non-Parametric Bound Estimates by Target Parameter",
)
fig.update_xaxes(title_text="Upper Bound of Target Parameter")
fig.update_yaxes(title_text="Bound Estimate")

fig.show()

In [None]:
# Instrument description passed to `identification`, taken from the DGP.
INSTRUMENT = {
    "support_z": DGP["support_z"],
    "pscore_z": DGP["pscore_z"],
    "pdf_z": DGP["pdf_z"],
}

# One slot per target value. NaN marks targets for which no identification result
# has been computed; a zero-filled array would make them look like real bounds
# of 0.0.
upper_bounds = np.full(len(unique_target_values), np.nan)
lower_bounds = np.full(len(unique_target_values), np.nan)

# The identified estimands do not depend on the target, so build them once:
# one "cross" estimand for every pair in {0, 1} x {0, 1, 2}.
cross_estimands = [
    {"type": "cross", "dz_cross": list(comb)} for comb in product([0, 1], [0, 1, 2])
]

# Currently only the target with u_hi = 0.9 is evaluated.
for u_hi_target in [0.9]:
    late_target = {
        "type": "late",
        "u_lo": 0.35,
        "u_hi": u_hi_target,
    }

    u_partition = _compute_u_partition(target=late_target, pscore_z=DGP["pscore_z"])
    bfuncs = _generate_basis_funcs("constant", u_partition)

    # Debug output of the inputs handed to `identification`.
    print(late_target)
    print(u_partition)
    print(bfuncs)
    print(SETUP_FIG5["identified_estimands"])

    bounds = identification(
        target=late_target,
        identified_estimands=cross_estimands,
        basis_funcs=bfuncs,
        m0_dgp=DGP["m0"],
        m1_dgp=DGP["m1"],
        u_partition=u_partition,
        instrument=INSTRUMENT,
        analytical_integration=False,
    )

    print(bounds)

    # Locate the slot of this target with a tolerance rather than exact float
    # equality, so tiny representation differences cannot cause a silent miss.
    is_target = np.isclose(unique_target_values, u_hi_target)
    upper_bounds[is_target] = bounds["upper_bound"]
    lower_bounds[is_target] = bounds["lower_bound"]

In [None]:
# Show the "identified_estimands" entry of SETUP_FIG5 (imported from pyvmte.config).
print(SETUP_FIG5["identified_estimands"])

In [None]:
# Only the last expression of a cell is rendered, so a bare `lower_bounds` line
# followed by `upper_bounds` would show the upper bounds only. Display both as a
# (lower, upper) tuple instead.
lower_bounds, upper_bounds

In [None]:
# Collect the identification results in one frame: one row per target value with
# the columns "u_hi", "upper_bound" and "lower_bound" (in that order).
df_identification = pd.DataFrame({"u_hi": unique_target_values}).assign(
    upper_bound=upper_bounds,
    lower_bound=lower_bounds,
)

df_identification

In [None]:
# Load the pickled bounds-by-target table stored under BLD/python/data.
data_bounds = pd.read_pickle(  # noqa: S301
    BLD.joinpath("python", "data", "bounds_by_target.pickle"),
)
data_bounds.head()

In [None]:
# Keep only the rows whose u_hi is at least 0.45.
data_bounds = data_bounds.loc[lambda frame: frame["u_hi"] >= 0.45]  # noqa: PLR2004

In [None]:
# Line plot of the bounds in data_bounds against u_hi: the upper bound in green
# first, then the lower bound in blue.
fig = go.Figure()

fig.add_traces(
    [
        go.Scatter(
            x=data_bounds["u_hi"],
            y=data_bounds[column],
            name=label,
            line_color=color,
        )
        for column, label, color in (
            ("upper_bound", "Upper Bound", "green"),
            ("lower_bound", "Lower Bound", "blue"),
        )
    ],
)