In [None]:
import matplotlib.pyplot as plt

import seaborn as sns

from sklearn.datasets import make_blobs

import pandas as pd

import numpy as np

from efficient_probit_regression.probit_model import PGeneralizedProbitModel
from efficient_probit_regression import settings
from efficient_probit_regression.datasets import Example2D

In [None]:
def get_line_y(x: np.ndarray, beta: np.ndarray):
    """Solve ``beta[0] * x + beta[1] * y + beta[2] = 0`` for ``y``.

    Parameters
    ----------
    x:
        First-coordinate value(s) at which the line is evaluated.
    beta:
        Indexable with at least three entries; ``beta[0]`` and ``beta[1]``
        weight the two coordinates and ``beta[2]`` is the constant term.
        ``beta[1]`` is used as a divisor, so it must be non-zero for the
        result to be finite.

    Returns
    -------
    The second-coordinate value(s) of the line, same shape as ``x``.
    """
    slope = -beta[0] / beta[1]
    offset = beta[2] / beta[1]
    return slope * x - offset

def get_lines_df(data_X: np.ndarray, data_y: np.ndarray, p_list: list, x_min=-5, x_max=5, step=0.01):
    """Fit one model per value of ``p`` and tabulate the resulting lines.

    For every ``p`` in ``p_list`` a ``PGeneralizedProbitModel`` is fitted on
    ``(data_X, data_y)``; its parameters are passed to ``get_line_y`` to
    evaluate the corresponding line on a shared grid of ``x`` values.

    Parameters
    ----------
    data_X, data_y:
        Passed unchanged to ``PGeneralizedProbitModel`` as ``X`` and ``y``.
    p_list:
        Values of ``p`` to fit; one line is produced per entry.
    x_min, x_max:
        Half-open range ``[x_min, x_max)`` of the evaluation grid.
    step:
        Grid spacing. Defaults to the previously hard-coded ``0.01``.

    Returns
    -------
    pandas.DataFrame
        Long-format frame with columns ``x1`` (grid value), ``x2`` (line
        value) and ``p``. If ``p_list`` is empty, an empty frame with these
        columns is returned.
    """
    columns = ["x1", "x2", "p"]
    x = np.arange(x_min, x_max, step)

    df_list = []
    for p in p_list:
        model = PGeneralizedProbitModel(p=p, X=data_X, y=data_y)
        model.fit()
        beta = model.get_params()
        # The scalar ``p`` is broadcast over all grid rows by the constructor.
        df_list.append(pd.DataFrame({"x1": x, "x2": get_line_y(x, beta), "p": p}))

    # pd.concat raises ValueError on an empty list, so handle that case here.
    if not df_list:
        return pd.DataFrame(columns=columns)

    return pd.concat(df_list, ignore_index=True)

In [None]:
dataset = Example2D()
X, y = dataset.get_X(), dataset.get_y()

# Scatter data: the first two columns of X are the plotted coordinates.
df = pd.DataFrame(X[:, [0, 1]], columns=["x1", "x2"])
# Labels equal to -1 are mapped to 0 for the hue column; others are kept.
df["y"] = np.where(y == -1, 0, y)

# Visible axis ranges. The lines are sampled over the same x-range so that
# they span the full width of the plot instead of stopping at the default
# grid end of get_lines_df (x = 5) while the axis extends to x = 6.
x_lim = (-3, 6)
y_lim = (-4, 6)

lines_df = get_lines_df(
    X, y, p_list=[1, 1.5, 2, 3, 5], x_min=x_lim[0], x_max=x_lim[1]
)

# use TeX for typesetting
plt.rcParams["text.usetex"] = True
plt.rc("font", size=15)

fig, ax = plt.subplots()
sns.scatterplot(data=df, x="x1", y="x2", hue="y", legend=False, ax=ax)

sns.lineplot(data=lines_df, x="x1", y="x2", hue="p", ax=ax, palette="flare")

ax.set_xlim(left=x_lim[0], right=x_lim[1])
ax.set_ylim(bottom=y_lim[0], top=y_lim[1])

ax.legend(loc="lower right", title="p")
ax.set_xlabel("$x_1$")
ax.set_ylabel("$x_2$")

ax.set_title("Multiple values of p", fontsize=23)

fig.tight_layout()

# Save through the figure object rather than pyplot's implicit current figure.
fig.savefig(settings.PLOTS_DIR / "2d-example.pdf")

# plt.show() instead of fig.show(): the latter only emits a UserWarning on
# non-GUI backends such as the notebook inline backend.
plt.show()

In [None]:
# For each p: fit the model, project every row of X onto the normalised
# parameter vector, and keep the values of the points whose sign disagrees
# with the label y.
df_list = []
for p in [1, 1.5, 2, 3, 5]:
    model = PGeneralizedProbitModel(p=p, X=X, y=y)
    model.fit()
    beta = model.get_params()
    # NOTE(review): the norm covers every entry of beta. If the last entry is
    # a constant term (as get_line_y's use of beta[2] suggests), this is not
    # the Euclidean distance to the line -- confirm this is intended.
    residuals = X @ (beta / np.linalg.norm(beta))
    # residual * y < 0  <=>  sign of the projection and the label differ.
    residuals_false = residuals[residuals * y < 0]
    # The scalar p is broadcast over all rows by the constructor.
    cur_df = pd.DataFrame({"residual": residuals_false, "p": p})
    df_list.append(cur_df)

df_residuals = pd.concat(df_list, ignore_index=True)

# use TeX for typesetting
plt.rcParams["text.usetex"] = True
plt.rc("font", size=15)

fig, ax = plt.subplots()
sns.boxplot(data=df_residuals, x="p", y="residual", ax=ax, palette="flare")

ax.set_title("Distribution of the Residuals", fontsize=23)

fig.tight_layout()

# Save through the figure object rather than pyplot's implicit current figure.
fig.savefig(settings.PLOTS_DIR / "residual-plot.pdf")

# plt.show() instead of fig.show(): the latter only emits a UserWarning on
# non-GUI backends such as the notebook inline backend.
plt.show()