In [None]:
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D

# Colour-name -> hex lookup for the four segment groups. The names are the
# colour keys used in segments_2d / segments_3d (palette chosen to mirror
# an existing TikZ figure).
color_map = dict(
    brown='#A0522D',
    orange='#FFA500',
    olive='#808000',
    yellow='#FFD700',
)
# 2D segment table -- one tuple per segment d_k, columns:
#   x        population unit (horizontal position of the bar)
#   y0, y1   vertical extent of the bar
#   colour   key into color_map
#   lw       line width in points
#   label    LaTeX label text
#   label_y  height at which the label is drawn
segments_2d = [
    # x  y0    y1    colour    lw  label        label_y
    (1,  0,    0.7,  'brown',  10, r'$d_1$',    0.44),
    (2,  0.7,  1,    'brown',   6, r'$d_2$',    0.85),
    (3,  0,    0.4,  'orange',  6, r'$d_3$',    0.24),
    (4,  0.4,  1,    'orange',  6, r'$d_4$',    0.75),
    (4,  0,    0.2,  'olive',   6, r'$d_5$',    0.1),
    (5,  0.2,  0.5,  'olive',   6, r'$d_6$',    0.35),
    (6,  0.5,  1,    'olive',   6, r'$d_7$',    0.7),
    (6,  0,    0.15, 'yellow',  6, r'$d_8$',    0.05),
    (7,  0.15, 0.6,  'yellow',  6, r'$d_9$',    0.35),
    (8,  0.6,  0.8,  'yellow',  6, r'$d_{10}$', 0.7),
    (9,  0.8,  1,    'yellow', 10, r'$d_{11}$', 0.9),
]
# 3D segment table -- one tuple per segment d_k, columns:
#   x, y       grid position of the unit; these pairs match units 1-9 of
#              segments_2d laid out on a 3x3 grid
#              (unit u -> x = (u-1)//3 + 1, y = (u-1)%3 + 1)
#   z0, z1     vertical extent of the segment
#   colour     key into color_map
#   lw         line width in points
#   label      LaTeX label text
#   label_xyz  explicit (x, y, z) anchor for the label
# NOTE(review): d_11's label breaks the (x, y+0.1, z) pattern of the other
# rows -- presumably to keep it readable from the chosen view; confirm.
segments_3d = [
    # x  y  z0    z1    colour    lw  label        label_xyz
    (1,  1, 0,    0.7,  'brown',  10, r'$d_1$',    (1, 1.1, 0.35)),
    (1,  2, 0.7,  1,    'brown',   6, r'$d_2$',    (1, 2.1, 0.85)),
    (1,  3, 0,    0.4,  'orange',  6, r'$d_3$',    (1, 3.1, 0.2)),
    (2,  1, 0.4,  1,    'orange',  6, r'$d_4$',    (2, 1.1, 0.8)),
    (2,  1, 0,    0.2,  'olive',   6, r'$d_5$',    (2, 1.1, 0.1)),
    (2,  2, 0.2,  0.5,  'olive',   6, r'$d_6$',    (2, 2.1, 0.35)),
    (2,  3, 0.5,  1,    'olive',   6, r'$d_7$',    (2, 3.1, 0.7)),
    (2,  3, 0,    0.15, 'yellow',  6, r'$d_8$',    (2, 3.1, 0.05)),
    (3,  1, 0.15, 0.6,  'yellow',  6, r'$d_9$',    (3, 1.1, 0.35)),
    (3,  2, 0.6,  0.8,  'yellow',  6, r'$d_{10}$', (3, 2.1, 0.7)),
    (3,  3, 0.8,  1,    'yellow', 10, r'$d_{11}$', (2.8, 2.9, 0.9)),
]

# Two stacked panels: a 2D strip (top third of the figure) drawing each
# segment of segments_2d as a vertical bar over its population unit, and a
# 3D view (bottom two thirds) drawing segments_3d on a 3x3 grid between the
# planes z=0 and z=1.
fig = plt.figure(figsize=(12, 10))

# --- Top plot (2D)
ax1 = plt.subplot2grid((3, 1), (0, 0), rowspan=1)
for x, y0, y1, c, w, lbl, yl in segments_2d:
    ax1.plot([x, x], [y0, y1], color=color_map[c], alpha=0.7, lw=w,
             solid_capstyle='round', zorder=2)
    # Label sits just right of its bar, at the table's chosen height.
    ax1.text(x + 0.1, yl, lbl, fontsize=14, va='center', ha='left')
# Pad the data limits. With xlim(1, 9) / ylim(0, 1) the bars at x=1 and x=9
# sat exactly on the axes edge and were cut in half (lines are clipped to the
# axes), and the round caps at y=0 and y=1 were clipped as well. The right
# pad also leaves room for the d_11 label drawn at x=9.1.
ax1.set_xlim(0.5, 9.7)
ax1.set_ylim(-0.05, 1.05)
ax1.set_xticks(np.arange(1, 10))
# Gridlines are drawn only at tick locations, so the previous
# set_yticks([]) silently turned the grid() call below into a no-op.
# Ticks at 0 and 1 give the intended horizontal reference lines and match
# the z-ticks of the 3D panel.
ax1.set_yticks([0, 1])
ax1.set_xlabel("Population Units")
ax1.set_ylabel("Inclusion Probabilities")
ax1.grid(axis='y', which='major', ls='-', color='lightgray')
ax1.set_title("2D Inclusion Probability Segments")

# --- Bottom plot (3D)
ax2 = plt.subplot2grid((3, 1), (1, 0), rowspan=2, projection='3d')
# Light, mostly transparent planes at z=0 and z=1 spanning the 3x3 grid.
xx, yy = np.meshgrid([1, 2, 3], [1, 2, 3])
zz0 = np.zeros_like(xx)
zz1 = np.ones_like(xx)
ax2.plot_surface(xx, yy, zz0, color='k', alpha=0.1, shade=False,
                 edgecolor='gray', linewidth=1)
ax2.plot_surface(xx, yy, zz1, color='k', alpha=0.1, shade=False,
                 edgecolor='gray', linewidth=1)
# Extra gray lines along x at y=2 and y=3 on both planes.
for y in [2, 3]:
    ax2.plot([1, 3], [y, y], [0, 0], lw=1, color='gray', alpha=0.4)
    ax2.plot([1, 3], [y, y], [1, 1], lw=1, color='gray', alpha=0.4)
# One vertical segment per d_k, labelled at its explicit (x, y, z) anchor.
# NOTE(review): mplot3d axes may recompute draw order by depth
# (computed_zorder), in which case zorder=10 has no effect -- confirm.
for x, y, z0, z1, c, w, lbl, lblpos in segments_3d:
    ax2.plot([x, x], [y, y], [z0, z1], color=color_map[c], alpha=0.7, lw=w,
             solid_capstyle='round', zorder=10)
    ax2.text(*lblpos, lbl, fontsize=13, va='center', ha='left')
ax2.set_xlim(0.7, 3.3)
ax2.set_ylim(0.7, 3.3)
ax2.set_zlim(0, 1)
ax2.set_xticks([1, 2, 3])
ax2.set_yticks([1, 2, 3])
ax2.set_zticks([0, 1])
ax2.view_init(12, 15)
ax2.set_ylabel("Population Units", labelpad=20)
ax2.set_zlabel("Inclusion Probabilities", labelpad=14)
ax2.set_title('3D Inclusion Probability Segments')
plt.tight_layout()
plt.show()
