In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.gridspec as gridspec

In [None]:
def plot_joint_distribution(data, th1=0.5, th2=0.5):
    """Plot two per-sample values jointly, with marginal KDEs and thresholds.

    The main panel scatters ``data[:, 0]`` (p1, x-axis) against ``data[:, 1]``
    (p2, y-axis), coloured by the quadrant each point falls in relative to
    the thresholds. The top panel shows filled KDEs of p1 and ``1 - p1``; the
    right panel shows filled KDEs of p2 and ``1 - p2``. Dashed red lines mark
    ``th1`` / ``th2`` in every panel. All axes are fixed to [0, 1].

    The plot is drawn on a newly created (and therefore current) pyplot
    figure, so callers can follow up with ``plt.savefig(...)``.

    Parameters
    ----------
    data : array-like, shape (n, 2)
        Column 0 is plotted on x, column 1 on y. Axes are clipped to [0, 1],
        so the values are presumably probabilities.
    th1 : float, default 0.5
        Threshold applied to column 0.
    th2 : float, default 0.5
        Threshold applied to column 1.

    Quadrant labels (used as the scatter colour values):
        1: p1 >= th1 and p2 >= th2
        2: p1 <  th1 and p2 >= th2
        3: p1 <  th1 and p2 <  th2
        4: p1 >= th1 and p2 <  th2
    Rows containing NaN get a NaN label instead of raising.

    Returns
    -------
    None
    """
    data = np.asarray(data, dtype=float)
    p1, p2 = data[:, 0], data[:, 1]

    # Compare in both directions explicitly (instead of negating one mask) so
    # that NaN rows match no condition and fall through to the NaN default.
    hi1, lo1 = p1 >= th1, p1 < th1
    hi2, lo2 = p2 >= th2, p2 < th2
    quadrant = np.select(
        [hi1 & hi2, lo1 & hi2, lo1 & lo2, hi1 & lo2],
        [1.0, 2.0, 3.0, 4.0],
        default=np.nan,
    )

    ticks = np.arange(0, 1.1, 0.25)

    fig = plt.figure(figsize=(6, 6))
    gs = gridspec.GridSpec(4, 4, figure=fig)
    ax_main = fig.add_subplot(gs[1:4, :3])
    ax_xDist = fig.add_subplot(gs[0, :3])
    ax_yDist = fig.add_subplot(gs[1:4, 3])

    # Main panel. Pinning vmin/vmax keeps each quadrant's colour fixed even
    # when some quadrants are empty (otherwise the colormap rescales).
    ax_main.scatter(p1, p2, c=quadrant, vmin=1, vmax=4, s=1, alpha=0.5)
    ax_main.axvline(th1, color='red', linewidth=1, linestyle='--')
    ax_main.axhline(th2, color='red', linewidth=1, linestyle='--')
    ax_main.set_xticks(ticks)
    ax_main.set_xlim([0, 1])
    ax_main.set_yticks(ticks)
    ax_main.set_ylim([0, 1])
    # Raw strings: '\l' is not a valid Python escape (SyntaxWarning on 3.12+).
    ax_main.set_xlabel(r'$p_1 = P(Y_1 \leq 1)$')
    # NOTE(review): the y label uses "\leq 2" while the x label uses "\leq 1";
    # kept as-is, but confirm this asymmetry is intended.
    ax_main.set_ylabel(r'$p_2 = P(Y_2 \leq 2)$')

    # Top marginal: density of p1 and of its complement, sharing the x range.
    sns.kdeplot(x=p1, fill=True, ax=ax_xDist)
    sns.kdeplot(x=1 - p1, fill=True, ax=ax_xDist)
    ax_xDist.axvline(th1, color='red', linewidth=1, linestyle='--')
    ax_xDist.set_xticks(ticks)
    ax_xDist.set_xticklabels([])
    ax_xDist.set_xlim([0, 1])
    ax_xDist.set_yticks([])
    ax_xDist.set_yticklabels([])
    ax_xDist.set_ylabel(None)
    ax_xDist.spines[['right', 'top', 'left']].set_visible(False)
    ax_xDist.tick_params(which='both', top=False)

    # Right marginal: density of p2 and of its complement, drawn vertically.
    sns.kdeplot(y=p2, fill=True, ax=ax_yDist)
    sns.kdeplot(y=1 - p2, fill=True, ax=ax_yDist)
    ax_yDist.axhline(th2, color='red', linewidth=1, linestyle='--')
    ax_yDist.set_yticks(ticks)
    ax_yDist.set_yticklabels([])
    ax_yDist.set_ylim([0, 1])
    ax_yDist.set_xticks([])
    ax_yDist.set_xticklabels([])
    ax_yDist.set_xlabel(None)
    ax_yDist.spines[['bottom', 'top', 'right']].set_visible(False)
    ax_yDist.tick_params(which='both', right=False)

    return None

In [None]:
# plot_joint_distribution(data, th1=0.4, th2=0.6)
# plt.savefig(PROJECT_PATH / 'joint_distribution.png', dpi=600)