In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal, norm
from scipy.interpolate import interp1d
from scipy.optimize import root_scalar
import ipywidgets as widgets
from IPython.display import display
from ipywidgets import FloatSlider, HBox, VBox, interactive_output, Layout
from IPython.display import HTML

In [None]:
def run_model(alpha=0.3, eta=1.0, gamma=3.0, c=1.0, tau=0.5, gov=0.5,
              sigma_q=0.5, sigma_theta=1.0, rho=0.5,
              mu_b_SH=0.5, sigma_b_SH=1,
              mu_p_SH=0.5, sigma_p_SH=1,
              b_BH=0, p_BH=0):
    """Solve the blockholder voting model over a grid of trades ``y`` and plot it.

    For every trade ``y`` in ``np.linspace(-0.4, 0.4, 200)`` the pivotal signal
    ``q*``, prices, voting premium, shareholder support and blockholder profit
    are computed; grid points where ``q*`` cannot be solved are dropped. The
    profit-maximising trade ``y*`` is highlighted in a 2x2 figure:

    * top-left: shareholder / blockholder / total votes vs. the signal ``q`` at ``y*``;
    * top-right: voting premium and market value of voting rights vs. ``y``;
    * bottom-left: ``dq*/dy`` and ``q*`` vs. ``y``;
    * bottom-right: blockholder profit vs. ``y``.

    Parameters
    ----------
    alpha : float
        Blockholder's vote share before trading (it votes ``alpha + y``).
    eta : float
        Quadratic trading-cost coefficient (profit subtracts ``eta / 2 * y**2``);
        also enters the price weight ``gamma / (2 * gamma + eta)``.
    gamma : float
        Scale dividing each shareholder type's trade and entering the price
        weight ``gamma / (2 * gamma + eta)``.
    c : float
        Constant added to every price/value; ``p* - c`` is plotted as the
        market value of voting rights.
    tau : float
        Vote threshold (share of votes in favour that is drawn as the cutoff).
    gov : float
        Weight on the ``b`` dimension in trades and values.
    sigma_q : float
        Standard deviation of the signal ``q ~ N(0, sigma_q**2)``.
    sigma_theta : float
        Parameter of ``theta`` used in ``E[theta | q > q*]``.
    rho : float
        Correlation of the shareholders' bivariate-normal ``(b, p)`` prior.
    mu_b_SH, sigma_b_SH, mu_p_SH, sigma_p_SH : float
        Means and standard deviations of that shareholder prior.
    b_BH, p_BH : float
        The blockholder's own ``b`` and ``p``.

    Returns
    -------
    None
        Output is only displayed (``plt.show()``); if no grid point can be
        solved a message is printed instead.
    """
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid; use the
    # new name when present (the ``or`` short-circuits past the old alias).
    trapezoid = getattr(np, "trapezoid", None) or np.trapz

    def H_q_star(q_star):
        """P(q > q_star) for q ~ N(0, sigma_q**2)."""
        return norm.sf(q_star, loc=0, scale=sigma_q)

    def expected_theta_given_q_gt(q_star):
        """E[theta | q > q_star], computed as weight * sigma_q * inverse Mills ratio."""
        z = q_star / sigma_q
        # norm.sf(z) rather than 1 - norm.cdf(z): the latter rounds to exactly 0
        # for z above ~8 and would turn the ratio into inf/nan.
        imr = norm.pdf(z) / norm.sf(z)
        weight = sigma_theta**2 / (sigma_theta**2 + sigma_q**2)
        return weight * sigma_q * imr

    def f_q(q):
        """Density of the signal q ~ N(0, sigma_q**2)."""
        return norm.pdf(q, loc=0, scale=sigma_q)

    def build_joint_density_and_marginals(y, q_star_val):
        """Reweight the shareholder (b, p) prior for trade y and return its summaries.

        Returns ``(R_interp, Lambda_interp, E_p_post, E_b_post)``: linear
        interpolants of the CDFs of p and b on the fixed grid [-3, 3] (0 below,
        1 above) and the means of p and b under the reweighted density.
        """
        grid_size = 100
        b_vals = np.linspace(-3, 3, grid_size)
        p_vals = np.linspace(-3, 3, grid_size)
        db = b_vals[1] - b_vals[0]
        dp = p_vals[1] - p_vals[0]
        # indexing='ij': axis 0 runs over b, axis 1 over p.
        B, P = np.meshgrid(b_vals, p_vals, indexing='ij')
        mean = [mu_b_SH, mu_p_SH]
        cov = [[sigma_b_SH**2, rho * sigma_b_SH * sigma_p_SH],
               [rho * sigma_b_SH * sigma_p_SH, sigma_p_SH**2]]
        psi = multivariate_normal.pdf(np.stack([B.ravel(), P.ravel()], axis=-1),
                                      mean=mean, cov=cov).reshape(B.shape)
        H = H_q_star(q_star_val)
        # Type-specific trade of a shareholder with characteristics (b, p).
        x_vals = (1 / gamma) * ((P - mu_p_SH) * H + (B - mu_b_SH) * gov) - y
        # Reweight each type by (1 - alpha + x) / (1 - alpha - y) (presumably its
        # post-trade holding share — confirm against the model write-up), then
        # renormalise so phi sums to 1 on the grid.
        phi = psi * (1 - alpha + x_vals) / (1 - alpha - y)
        phi /= np.sum(phi) * db * dp
        r_p = trapezoid(phi, b_vals, axis=0)       # integrate out b -> density of p
        lambda_b = trapezoid(phi, p_vals, axis=1)  # integrate out p -> density of b
        R_interp = interp1d(p_vals, np.clip(np.cumsum(r_p) * dp, 0, 1),
                            kind='linear', bounds_error=False, fill_value=(0, 1))
        Lambda_interp = interp1d(b_vals, np.clip(np.cumsum(lambda_b) * db, 0, 1),
                                 kind='linear', bounds_error=False, fill_value=(0, 1))
        E_p_post = np.sum(P * phi) * db * dp
        E_b_post = np.sum(B * phi) * db * dp
        return R_interp, Lambda_interp, E_p_post, E_b_post

    def shareholder_support(q, R_func, alpha, y):
        """Shareholder votes in favour at signal q: (1 - alpha - y) * (1 - R(-q))."""
        return (1 - alpha - y) * (1 - R_func(-q))

    def find_q_for_support(target_s, R_func, alpha, y, q_bounds=(-5, 5)):
        """Signal q in q_bounds where shareholder support equals target_s, else None."""
        def s_diff(q):
            return shareholder_support(q, R_func, alpha, y) - target_s
        try:
            sol = root_scalar(s_diff, bracket=q_bounds, method='bisect')
        except (ValueError, RuntimeError):
            # ValueError: s_diff has the same sign at both ends of the bracket,
            # i.e. the target support is never reached inside q_bounds.
            return None
        return sol.root if sol.converged else None

    def compute_q_star(y):
        """Return (q*, q_low, q_high, median_bias) for trade y, or four Nones.

        q_low / q_high are the signals at which shareholder support alone equals
        tau - (alpha + y) and tau. q* is p_BH clamped into [-q_high, -q_low];
        median_bias is returned equal to q*.
        """
        # NOTE(review): q* is unknown here, so the density used to locate the
        # support thresholds is built at a fixed placeholder q* = 0.2.
        R_interp_temp, _, _, _ = build_joint_density_and_marginals(y, 0.2)
        q_low = find_q_for_support(tau - (alpha + y), R_interp_temp, alpha, y)
        q_high = find_q_for_support(tau, R_interp_temp, alpha, y)
        if q_low is None or q_high is None:
            return None, None, None, None
        if p_BH >= -q_low:
            return -q_low, q_low, q_high, -q_low
        elif p_BH <= -q_high:
            return -q_high, q_low, q_high, -q_high
        else:
            return p_BH, q_low, q_high, p_BH

    def dq_star_dy(y, h=1e-7):
        """Central finite difference of q* in y; 0 if q* is unsolvable at y +/- h."""
        q_plus, _, _, _ = compute_q_star(y + h)
        q_minus, _, _, _ = compute_q_star(y - h)
        if q_plus is None or q_minus is None:
            return 0
        return (q_plus - q_minus) / (2 * h)

    def mpv_term(y, q_star_val, E_p_SH, dqdy):
        """The MPV term shared by the first-order condition and the price p*."""
        return (alpha * (q_star_val + p_BH) + y * (p_BH - E_p_SH)) * f_q(q_star_val) * (-dqdy)

    def equilibrium_first_order_condition(y):
        """Blockholder first-order condition in y (nan if q* is unsolvable).

        Diagnostic only: run_model does not evaluate it for the plots.
        """
        q_star_val, _, _, _ = compute_q_star(y)
        if q_star_val is None:
            return np.nan
        _, _, E_p_SH, E_b_SH = build_joint_density_and_marginals(y, q_star_val)
        H = H_q_star(q_star_val)
        MPV = mpv_term(y, q_star_val, E_p_SH, dq_star_dy(y))
        return (p_BH - E_p_SH) * H + (b_BH - E_b_SH) * gov - (2 * gamma + eta) * y + MPV

    def compute_prices(y):
        """Solve the model at trade y and return one row of results.

        Row layout: [y, p*, v_bar, premium (p* - v_bar), MV of votes (p* - c),
        total votes, shareholder support s, alpha, y, median_bias, q_low, q_high,
        v_BH, q*, dq*/dy]. When q* cannot be solved, every entry after y is nan
        (same length as a successful row).
        """
        q_star_val, q_low, q_high, median_bias = compute_q_star(y)
        if q_star_val is None:
            return [y] + [np.nan] * 14
        R_interp, _, E_p_SH, E_b_SH = build_joint_density_and_marginals(y, q_star_val)
        H = H_q_star(q_star_val)
        E_theta = expected_theta_given_q_gt(q_star_val)
        w = gamma / (2 * gamma + eta)
        bar_p = w * p_BH + (1 - w) * E_p_SH
        bar_b = w * b_BH + (1 - w) * E_b_SH
        dqdy = dq_star_dy(y)
        MPV = mpv_term(y, q_star_val, E_p_SH, dqdy)
        p_star = c + bar_b * gov + (bar_p + E_theta) * H + w * MPV
        bar_v = c + bar_b * gov + (bar_p + E_theta) * H
        v_BH = c + b_BH * gov + (p_BH + E_theta) * H
        s = shareholder_support(q_star_val, R_interp, alpha, y)
        return [y, p_star, bar_v, p_star - bar_v, p_star - c, s + alpha + y, s,
                alpha, y, median_bias, q_low, q_high, v_BH, q_star_val, dqdy]

    # Evaluate over the y grid: each point is solved exactly once and rows where
    # q* could not be solved (p* is nan) are dropped.
    y_vals = np.linspace(-0.4, 0.4, 200)
    rows = [compute_prices(yy) for yy in y_vals]
    results = np.array([row for row in rows if not np.isnan(row[1])])
    if results.size == 0:
        print("Model failed to compute for current parameters.")
        return
    (ys, p_stars, bar_vs, premiums, mv_votes, votes_total, s_vals, alpha_vals,
     y_vals_plot, med_bias, q_lows, q_highs, v_BH, q_star_vals, dq_vals) = results.T

    # Blockholder profit: value of the post-trade stake, minus the cost of the
    # y shares at p*, minus the quadratic trading cost.
    profit_vals = (alpha_vals + y_vals_plot) * v_BH - y_vals_plot * p_stars - (eta / 2) * y_vals_plot**2

    # Profit-maximising trade; nanargmax so a nan profit is never picked as the max.
    if np.isnan(profit_vals).all():
        y_star = None
    else:
        y_star = ys[np.nanargmax(profit_vals)]

    # ------------------- PLOTTING -------------------
    fig, axs = plt.subplots(2, 2, figsize=(11, 9))

    # Top-left: Voting outcome vs signal
    q_range = np.linspace(-3, 3, 300)
    if y_star is not None:
        y_fixed = y_star
    else:
        # Fall back to the solved trade closest to 0, where q* is known to exist.
        y_fixed = ys[np.argmin(np.abs(ys))]
    q_star_val, q_low_star, q_high_star, _ = compute_q_star(y_fixed)
    R_interp_star, _, E_p_SH_star, _ = build_joint_density_and_marginals(y_fixed, q_star_val)

    s_qs = (1 - alpha - y_fixed) * (1 - R_interp_star(-q_range))
    b_qs = np.where(q_range > -p_BH, alpha + y_fixed, 0.0)
    total_votes = s_qs + b_qs

    axs[0, 0].fill_between(q_range, 0, total_votes.max() * 1.05,
                           where=(q_range >= q_low_star) & (q_range <= q_high_star),
                           color='lightgray', alpha=0.5, label="Pivotality Region")

    axs[0, 0].plot(q_range, s_qs, '--', label="Shareholder Support", color='blue')
    axs[0, 0].plot(q_range, b_qs, ':', label="Blockholder Support", color='green')
    axs[0, 0].plot(q_range, total_votes, label="Total Votes", color='black')

    axs[0, 0].hlines(tau, -3, 3, linestyle='--', color='gray', label="Threshold $\\tau$")

    axs[0, 0].set_xlabel("Signal $q$")
    axs[0, 0].set_ylabel("Votes in Favor")
    axs[0, 0].set_title("Voting Outcome vs Signal")
    axs[0, 0].grid(True)

    # Bias arrows pointing down to the x-axis
    axs[0, 0].annotate("$-p_{BH}$", xy=(-p_BH, 0), xytext=(-p_BH, 0.1),
                       ha='center', va='bottom',
                       arrowprops=dict(arrowstyle='-|>', color='black'),
                       fontsize=8, color='black')

    axs[0, 0].annotate(r"$-E[p_{SH}]$", xy=(-E_p_SH_star, 0), xytext=(-E_p_SH_star, 0.1),
                       ha='center', va='top',
                       arrowprops=dict(arrowstyle='-|>', color='black'),
                       fontsize=8, color='black')

    axs[0, 0].legend()

    # Top-right: Premium & MV
    axs[0, 1].plot(ys, premiums, label="Voting Premium")
    axs[0, 1].plot(ys, mv_votes, label="Market Value of Voting Rights")
    if y_star is not None:
        axs[0, 1].axvline(x=y_star, color='red', linestyle='--', label="$y^*$")
    axs[0, 1].set_xlabel("Blockholder Trade $y$")
    axs[0, 1].set_ylabel("Valuation")
    axs[0, 1].set_title("VP and MV of votes")
    axs[0, 1].legend()
    axs[0, 1].grid(True)

    # Bottom-left: dq*/dy and q* vs y (values already computed per grid point)
    ax1 = axs[1, 0]
    ax1.plot(ys, dq_vals, label=r"$\frac{dq^*}{dy}$", color='tab:blue')
    ax1.axhline(0, color='black', linestyle='--')
    if y_star is not None:
        ax1.axvline(x=y_star, color='red', linestyle='--', label=r"$y^*$")
    ax1.set_xlabel("Blockholder Trade $y$")
    ax1.set_ylabel(r"$\frac{dq^*}{dy}$", color='tab:blue')
    ax1.tick_params(axis='y', labelcolor='tab:blue')
    ax1.set_title("Sensitivity and Level of $q^*$")
    ax1.grid(True)

    # Second y-axis for q*
    ax2 = ax1.twinx()
    ax2.plot(ys, q_star_vals, '--', label=r"$q^*$", color='tab:orange')
    ax2.set_ylabel(r"$q^*$", color='tab:orange')
    ax2.tick_params(axis='y', labelcolor='tab:orange')

    # Combine legends from both y-axes
    lines1, labels1 = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax1.legend(lines1 + lines2, labels1 + labels2, loc="lower left")

    # Bottom-right: Profit only
    ax = axs[1, 1]
    ax.plot(ys, profit_vals, label="Profit", color='tab:blue')
    ax.set_xlabel("Blockholder Trade $y$")
    ax.set_ylabel("Blockholder Profit", color='tab:blue')
    ax.tick_params(axis='y', labelcolor='tab:blue')
    ax.set_title("Profit vs $y$")
    ax.grid(True)

    if y_star is not None:
        ax.axvline(x=y_star, color='red', linestyle='--', label="$y^*$")

    ax.legend(loc="upper left")

    plt.tight_layout()
    fig.set_dpi(200)  # render the figure at higher resolution
    plt.show()

In [None]:
def run_model_lite(b_BH, p_BH, mu_b_SH, mu_p_SH, rho):
    """Run ``run_model`` with only the five slider-controlled inputs overridden.

    Every other ``run_model`` parameter keeps its default value.
    """
    overrides = dict(b_BH=b_BH, p_BH=p_BH, mu_b_SH=mu_b_SH,
                     mu_p_SH=mu_p_SH, rho=rho)
    run_model(**overrides)

In [None]:
# --- Interactive controls ---
# 1. One slider per exposed model input. All five share the same range, step
#    and "update only when the handle is released" behaviour, kept in one place.
_SLIDER_OPTS = dict(min=-1, max=1, step=0.1, continuous_update=False)
slider_b_BH = widgets.FloatSlider(description='b_BH', value=0.5, **_SLIDER_OPTS)
slider_p_BH = widgets.FloatSlider(description='p_BH', value=0.2, **_SLIDER_OPTS)
slider_mu_b_SH = widgets.FloatSlider(description='μ_b_SH', value=0.5, **_SLIDER_OPTS)
slider_mu_p_SH = widgets.FloatSlider(description='μ_p_SH', value=0.5, **_SLIDER_OPTS)
slider_rho = widgets.FloatSlider(description='ρ', value=0.5, **_SLIDER_OPTS)

In [None]:
# 2. Arrange the five sliders side by side in one horizontal row.
slider_row = HBox(children=[slider_b_BH, slider_p_BH, slider_mu_b_SH,
                            slider_mu_p_SH, slider_rho])

In [None]:
# 3. Wrapper that the sliders drive.
# NOTE(review): this redefines run_model_lite from an earlier cell with the same
# behaviour; the later definition silently wins. One of the two should be removed.
def run_model_lite(b_BH, p_BH, mu_b_SH, mu_p_SH, rho):
    """Forward the five slider values to ``run_model``; all other inputs use defaults."""
    run_model(rho=rho, mu_p_SH=mu_p_SH, mu_b_SH=mu_b_SH, p_BH=p_BH, b_BH=b_BH)

In [None]:
# Re-run run_model_lite whenever a slider changes; each key is the wrapper's
# parameter name and each value is the slider that supplies it.
interactive_ui = widgets.interactive_output(
    run_model_lite,
    dict(b_BH=slider_b_BH, p_BH=slider_p_BH, mu_b_SH=slider_mu_b_SH,
         mu_p_SH=slider_mu_p_SH, rho=slider_rho),
)

In [None]:
# 4. Inject CSS capping the width of JupyterLab/Voila output children at 70%,
#    then show the horizontal slider row above the plot output.
display(
    HTML("""
<style>
.jp-OutputArea-output > div {
    max-width: 70% !important;
}
</style>
"""),
    VBox(children=[slider_row, interactive_ui]),
)

In [None]:
# Alternative (unused) UI, kept for reference: ipywidgets.interactive builds the
# sliders and the output area in a single call (without continuous_update=False).
# Enabling it would replace the interactive_ui built above. Written as comments
# so that this cell has no output.
#
# interactive_ui = widgets.interactive(
#     run_model_lite,
#     b_BH=widgets.FloatSlider(value=0.5, min=-1, max=1, step=0.1, description='b_BH'),
#     p_BH=widgets.FloatSlider(value=0.2, min=-1, max=1, step=0.1, description='p_BH'),
#     mu_b_SH=widgets.FloatSlider(value=0.5, min=-1, max=1, step=0.1, description='μ_b_SH'),
#     mu_p_SH=widgets.FloatSlider(value=0.5, min=-1, max=1, step=0.1, description='μ_p_SH'),
#     rho=widgets.FloatSlider(value=0.5, min=-1, max=1, step=0.1, description='ρ')
# )
#
# display(interactive_ui)