In [None]:
"""Study the value function in the simple model with an identified late."""
import numpy as np
import pandas as pd  # type: ignore[import-untyped]
import plotly.graph_objects as go
from pyvmte.classes import Estimand  # type: ignore[import-untyped]
from pyvmte.config import IV_SM
from pyvmte.identification import identification  # type: ignore[import-untyped]
from pyvmte.utilities import (
    generate_bernstein_basis_funcs,
    generate_constant_splines_basis_funcs,
)

%load_ext autoreload
%autoreload 2

In [None]:
# Parameters
num_gridpoints = 100  # number of identified-LATE values on the grid below

k_bernstein = 11  # passed as `k` to `generate_bernstein_basis_funcs`

# NOTE: the optional restrictions (shape constraints, MTE monotonicity, monotone
# response) are switched on one at a time inside the estimation loop, so they
# are deliberately not configured here.

# Extra width above `pscore_hi` used by the target LATE (the partition below
# contains the point pscore_hi + u_hi_extra).
u_hi_extra = 0.2

pscore_lo = 0.4
pscore_hi = 0.6

# The partition points built below must be ordered and lie inside [0, 1].
assert 0 <= pscore_lo < pscore_hi <= pscore_hi + u_hi_extra <= 1, (
    "need 0 <= pscore_lo < pscore_hi <= pscore_hi + u_hi_extra <= 1"
)

target = Estimand(
    "late",
    u_lo=pscore_lo,
    u_hi=pscore_hi,
    u_hi_extra=u_hi_extra,
)

# "sharp" identified set: one cross-moment estimand per (d, z) cell.
identified_sharp = [
    Estimand(esttype="cross", dz_cross=(d, z)) for d in [0, 1] for z in [0, 1]
]

# "late" identified set: only the LATE over [pscore_lo, pscore_hi].
identified_late = [Estimand(esttype="late", u_lo=pscore_lo, u_hi=pscore_hi)]


instrument = IV_SM

# Sorted, de-duplicated partition of the unit interval.
u_partition = np.unique(np.array([0, pscore_lo, pscore_hi, pscore_hi + u_hi_extra, 1]))

In [None]:
# Basis functions for the MTRs: a constant-spline basis over `u_partition` and a
# Bernstein basis with parameter `k_bernstein`.
bfuncs_constant_splines = generate_constant_splines_basis_funcs(
    u_partition=u_partition,
)
bfuncs_bernstein = generate_bernstein_basis_funcs(
    k=k_bernstein,
)

In [None]:
def _at(u: float) -> bool:
    """Indicate the region u < pscore_lo (the `at` group).

    Strict inequality so the three regions [0, lo), [lo, hi), [hi, 1] are
    disjoint. Previously u == pscore_lo fell into both `_at` and `_c`.
    """
    return u < pscore_lo


def _c(u: float) -> bool:
    """Indicate the region pscore_lo <= u < pscore_hi (the `c` group)."""
    # `&` rather than `and` so this also works element-wise on NumPy arrays;
    # for plain floats it still returns a bool.
    return (pscore_lo <= u) & (u < pscore_hi)


def _nt(u: float) -> bool:
    """Indicate the region u >= pscore_hi (the `nt` group)."""
    return u >= pscore_hi


# Piecewise-constant MTR functions: constant within each of the three regions
# above. Whether they are monotone depends on the `c`-group values chosen later.
# For example, the MTE m1 - m0 is weakly decreasing only when
#   y1_nt - y0_nt <= y1_c - y0_c <= y1_at - y0_at   (0.15 and 0.5 below).
# TODO: This is a very extreme example.
y1_at = 0.75
y0_at = 0.25

y1_nt = 0.75
y0_nt = 0.6


# Define function factories to avoid late binding
# See https://stackoverflow.com/a/3431699
def _make_m0(y0_c):
    """Return the MTR `m0` (used as `m0_dgp`) whose `c`-region value is `y0_c`.

    `y0_c` is bound as a factory argument so that each returned function keeps
    its own value instead of the last value of a loop variable.
    """

    def _m0(u):
        return y0_at * _at(u) + y0_c * _c(u) + y0_nt * _nt(u)

    return _m0


def _make_m1(y1_c):
    """Return the MTR `m1` (used as `m1_dgp`) whose `c`-region value is `y1_c`.

    `y1_c` is bound as a factory argument so that each returned function keeps
    its own value instead of the last value of a loop variable.
    """

    def _m1(u):
        return y1_at * _at(u) + y1_c * _c(u) + y1_nt * _nt(u)

    return _m1

In [None]:
# Grid of identified LATE values on [-1, 1].
beta_late = np.linspace(-1, 1, num_gridpoints)

# Split each LATE symmetrically around 0.5. Then y1_c - y0_c == beta_late, and
# both values stay in [0, 1] because |beta_late| <= 1.
y1_c = 0.5 + beta_late / 2
y0_c = 0.5 - beta_late / 2

In [None]:
# Compute bounds on the target LATE for every combination of identified set
# ("late" / "sharp"), basis-function family and restriction, over the grid of
# identified LATE values.
df_by_bfunc_and_constraint_and_idset = []

# The basis depends only on the basis family, so pick it once per family rather
# than once per grid point.
basis_funcs_by_type = {
    "constant": bfuncs_constant_splines,
    "bernstein": bfuncs_bernstein,
}

for idset in ["late", "sharp"]:
    identified = identified_late if idset == "late" else identified_sharp

    for bfunc_type, basis_funcs in basis_funcs_by_type.items():
        for restriction in [
            "none",
            "shape_constraint",
            "mte_monotone",
            "monotone_response",
        ]:
            # Switch on only the restriction named by `restriction`; the others
            # stay None. Run-specific names keep the loop from overwriting any
            # same-named globals.
            run_shape_constraint = (
                ("decreasing", "decreasing")
                if restriction == "shape_constraint"
                else None
            )
            run_mte_monotone = "decreasing" if restriction == "mte_monotone" else None
            run_monotone_response = (
                "positive" if restriction == "monotone_response" else None
            )

            results = []
            for y1_c_val, y0_c_val in zip(y1_c, y0_c, strict=True):
                # Fresh MTR closures for this grid point.
                _m0 = _make_m0(y0_c_val)
                _m1 = _make_m1(y1_c_val)

                # The identified set might be empty for some parameter value
                # combinations; whatever bounds `identification` reports are
                # stored as-is (not checked here).
                res = identification(
                    target=target,
                    identified_estimands=identified,
                    basis_funcs=basis_funcs,
                    instrument=instrument,
                    u_partition=u_partition,
                    m0_dgp=_m0,
                    m1_dgp=_m1,
                    shape_constraints=run_shape_constraint,
                    mte_monotone=run_mte_monotone,
                    monotone_response=run_monotone_response,
                )
                results.append(res)

            # Collect this run's bounds into a DataFrame, one row per grid point.
            # The tuple-valued column is given as a list to avoid broadcasting.
            df_by_bfunc_and_constraint_and_idset.append(
                pd.DataFrame(
                    {
                        "y1_c": y1_c,
                        "y0_c": y0_c,
                        "upper_bound": [r.upper_bound for r in results],
                        "lower_bound": [r.lower_bound for r in results],
                        "shape_constraint": [run_shape_constraint] * len(results),
                        "y1_at": y1_at,
                        "y0_at": y0_at,
                        "y1_nt": y1_nt,
                        "y0_nt": y0_nt,
                        "bfunc_type": bfunc_type,
                        "constraint": restriction,
                        "idset": idset,
                    },
                ),
            )

# Combine the DataFrames with a fresh 0..n-1 index (each piece restarts at 0).
df_res = pd.concat(df_by_bfunc_and_constraint_and_idset, ignore_index=True)

df_res["b_late"] = df_res["y1_c"] - df_res["y0_c"]

In [None]:
# Rows computed with the Bernstein basis.
df_res.loc[df_res["bfunc_type"].eq("bernstein")]

In [None]:
bfunc_type_to_color = {"constant": "blue", "bernstein": "red"}
# The two bounds of one basis family share a color. Draw the lower bound dashed
# so the lines stay distinguishable, including where they coincide.
bound_to_dash = {"upper_bound": "solid", "lower_bound": "dash"}
bound_to_label = {"upper_bound": "Upper bound", "lower_bound": "Lower bound"}


def plot_bounds(df: pd.DataFrame, idset: str, constraint: str) -> go.Figure:
    """Plot bounds on the target LATE against the identified LATE.

    One (upper, lower) pair of line traces is drawn per basis-function family.

    Args:
        df: Results with columns "bfunc_type", "constraint", "idset", "b_late",
            "upper_bound" and "lower_bound".
        idset: Identified set to plot ("late" or "sharp").
        constraint: Restriction to plot.

    Returns:
        The plotly figure (not yet shown).
    """
    fig = go.Figure()

    for bfunc_type, color in bfunc_type_to_color.items():
        mask = (
            (df["bfunc_type"] == bfunc_type)
            & (df["constraint"] == constraint)
            & (df["idset"] == idset)
        )
        data = df.loc[mask]

        for bound in ["upper_bound", "lower_bound"]:
            fig.add_trace(
                go.Scatter(
                    x=data["b_late"],
                    y=data[bound],
                    mode="lines",
                    name=f"{bound_to_label[bound]} ({bfunc_type})",
                    legendgroup=bfunc_type,
                    # One trace per legend group carries the group title.
                    legendgrouptitle=(
                        {"text": bfunc_type} if bound == "upper_bound" else None
                    ),
                    line={"color": color, "dash": bound_to_dash[bound]},
                ),
            )

    fig.update_layout(
        title=f"Bounds on Target LATE: {constraint}, {idset}",
        xaxis_title="Identified LATE",
        yaxis_title="Bounds",
    )
    return fig


for idset in ["late", "sharp"]:
    for constraint in ["none", "shape_constraint", "mte_monotone", "monotone_response"]:
        plot_bounds(df_res, idset, constraint).show()