In [None]:
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import math
from pathlib import Path as PPath
from matplotlib.path import Path
from matplotlib.patches import PathPatch

# Ensure the output directory for the job-design plots exists,
# creating any missing parent directories along the way.
path = PPath("../writeup/plots") / "job_design"
path.mkdir(exist_ok=True, parents=True)


def draw_brace(ax, xy1, xy2, width, orientation='left', color='red', lw=1):
    """
    Draw a simple bracket between two points as one quadratic Bezier arc.

    This is not a true curly brace. It is a single quadratic (``CURVE3``)
    curve from ``xy1`` to ``xy2`` that is pulled toward a control point:

    * ``'left'``:  control point ``(x1 - width, (y1 + y2) / 2)``
    * ``'right'``: control point ``(x1 + width, (y1 + y2) / 2)``
    * ``'down'``:  control point ``((x1 + x2) / 2, y1 - width)``
    * ``'up'``:    control point ``((x1 + x2) / 2, y1 + width)``

    For a vertical ('left'/'right') or horizontal ('down'/'up') segment, the
    arc bulges ``width / 2`` away from the segment at its midpoint.

    Args:
        ax: matplotlib axes to draw on.
        xy1, xy2: ``(x, y)`` end points of the bracket.
        width: offset of the control point (see above).
        orientation: one of ``'left'``, ``'right'``, ``'down'``, ``'up'``.
        color, lw: line colour and width of the patch.

    Returns:
        The ``PathPatch`` that was added to ``ax``.

    Raises:
        ValueError: if ``orientation`` is not one of the supported values.
    """
    x1, y1 = xy1
    x2, y2 = xy2
    mid_x = (x1 + x2) / 2
    mid_y = (y1 + y2) / 2

    if orientation == 'left':
        ctrl = (x1 - width, mid_y)
    elif orientation == 'right':
        ctrl = (x1 + width, mid_y)
    elif orientation == 'down':
        ctrl = (mid_x, y1 - width)
    elif orientation == 'up':
        ctrl = (mid_x, y1 + width)
    else:
        # Previously an unknown orientation silently drew nothing.
        raise ValueError(
            "orientation must be 'left', 'right', 'down' or 'up', "
            f"got {orientation!r}"
        )

    verts = [xy1, ctrl, xy2]
    codes = [Path.MOVETO, Path.CURVE3, Path.CURVE3]
    # Named `curve` so it does not shadow the module-level output `path`.
    curve = Path(verts, codes)
    patch = PathPatch(curve, fill=False, color=color, lw=lw)
    ax.add_patch(patch)
    return patch


def create_title_with_worker_assignments(W):
    """
    Build a plot title showing which tasks each worker performs.

    Consecutive tasks with the same worker id are grouped in one pair of
    brackets, using 1-based task numbers; for example, ``[1, 1, 2]`` gives
    ``"Job Design [1,2][3]"``. Non-adjacent tasks with the same id are
    separate groups, e.g. ``[1, 2, 1]`` gives ``"Job Design [1][2][3]"``.

    Args:
        W: sequence of worker ids, one per task.

    Returns:
        The title string. An empty ``W`` gives ``"Job Design "``; the
        previous implementation raised ``IndexError`` in that case.
    """
    from itertools import groupby

    numbered = enumerate(W, 1)  # (task_number, worker_id) pairs
    sections = (
        "[" + ",".join(str(task_idx) for task_idx, _ in run) + "]"
        for _, run in groupby(numbered, key=lambda pair: pair[1])
    )
    return "Job Design " + "".join(sections)


def draw_rect_square_unit(ax, x, y, t, c, h, task_idx, rectanlge_annotation=True):
    """
    Draw one task as an outlined rectangle labelled with its task number.

    The rectangle's lower-left corner is ``(x, y)``; it is ``t`` wide and
    ``c`` tall. If ``rectanlge_annotation`` is true, a green bracket labelled
    ``c_<task_idx>`` is drawn along the left edge, and a red bracket labelled
    ``t_<task_idx>`` is drawn under the bottom edge.

    Returns:
        A 3-tuple ``(next_pos, coords, handoff_anchor)``:
        - ``next_pos``: ``(x + t + h, y + c)``, i.e. the top-right corner
          shifted right by ``h``.
        - ``coords``: one-element list holding the top-right corner
          ``(x + t, y + c)``.
        - ``handoff_anchor``: ``(x + t, y + c, h)``.
    """
    right_edge = x + t
    top_edge = y + c

    ax.add_patch(
        plt.Rectangle((x, y), t, c, fill=False, color="blue", linewidth=1.5)
    )
    ax.text(x + t / 2, y + c / 2, str(task_idx), ha="center", va="center")

    if rectanlge_annotation:
        # Height bracket on the left edge.
        draw_brace(ax, (x, top_edge), (x, y), width=0.25,
                   orientation='left', color='#1b9e77', lw=1)
        ax.text(x - 0.3, y + c / 2, rf'$c_{{{task_idx}}}$',
                ha='right', va='center', fontsize=10, color='#1b9e77')

        # Width bracket under the bottom edge.
        draw_brace(ax, (x, y), (right_edge, y), width=0.25,
                   orientation='down', color='red', lw=1)
        ax.text(x + t / 2, y - 0.3, rf'$t_{{{task_idx}}}$',
                ha='center', va='top', fontsize=10, color='red')

    corner = (right_edge, top_edge)
    return (right_edge + h, top_edge), [corner], (right_edge, top_edge, h)


def draw_rect_square_sequence(T, C, H, W,
                              rectanlge_annotation=True,
                              design_title_string=True):
    """
    Draw a job design: a staircase of task rectangles grouped by worker.

    Task ``i`` is drawn by ``draw_rect_square_unit`` as a ``T[i]`` by ``C[i]``
    rectangle. The rectangle starts at the top-right corner of the previous
    task (the first one starts at the origin). When the next task belongs to
    a different worker, the next task is shifted right by ``H[i]``. That gap
    is drawn as a pink strip labelled ``h_<i+1>``, but only if ``H[i] > 0``.
    When the next task belongs to the same worker, no gap is left.

    Each run of consecutive tasks with the same worker id is covered by a
    translucent blue box. The box starts at that run's first task and is
    ``sum(T)`` wide and ``sum(C)`` tall over the run.

    Args:
        T, C, H, W: equal-length sequences of task widths, task heights,
            handoff widths and worker ids.
        rectanlge_annotation: forwarded to ``draw_rect_square_unit`` to draw
            the per-task ``t_i`` / ``c_i`` brackets.
        design_title_string: if true, set the axes title from
            ``create_title_with_worker_assignments(W)``.

    Returns:
        ``(fig, ax, worker_sections)``. Each worker section is
        ``(start_x, start_y, end_x, worker_id, total_t, total_c)``, where
        ``end_x`` is the x position where drawing continues after the
        section, including any handoff gap.

    Raises:
        ValueError: if the inputs differ in length or are empty.
    """
    if not (len(T) == len(C) == len(H) == len(W)):
        raise ValueError("All input vectors must have the same length")
    # Validate before plt.subplots() so a bad call does not leave an open
    # figure behind (empty input used to fail later, at W[0]).
    if len(T) == 0:
        raise ValueError("At least one task is required")

    fig, ax = plt.subplots(figsize=(10, 8))

    n_tasks = len(T)
    current_pos = (0, 0)
    # Points used for the axis limits: the origin plus every task's
    # top-right corner.
    all_coords = [(0, 0)]

    worker_sections = []
    transition_rects = []

    # State of the worker section currently being accumulated.
    current_section_start_x = 0
    current_section_start_y = 0
    current_worker = W[0]
    current_worker_T = []
    current_worker_C = []

    for i, (t, c, h, w) in enumerate(zip(T, C, H, W)):
        next_pos, new_coords, handoff_anchor = draw_rect_square_unit(
            ax,
            current_pos[0],
            current_pos[1],
            t, c, h,
            i + 1,
            rectanlge_annotation
        )

        attached_x, attached_top_y, h_val = handoff_anchor

        is_last = i == n_tasks - 1
        # True when the next task is done by a different worker.
        hands_off = not is_last and w != W[i + 1]

        # Same worker continues: remove the handoff gap that
        # draw_rect_square_unit added to next_pos.
        if not is_last and not hands_off:
            next_pos = (next_pos[0] - h, next_pos[1])

        current_worker_T.append(t)
        current_worker_C.append(c)

        if hands_off or is_last:
            # The current worker section ends with this task.
            if hands_off:
                # The handoff strip's top is this task's top; it spans the
                # full height of the section that ends here.
                total_c = sum(current_worker_C)
                transition_rects.append(
                    (attached_x, attached_top_y, h_val, total_c, i + 1)
                )

            worker_sections.append(
                (
                    current_section_start_x,
                    current_section_start_y,
                    next_pos[0],
                    current_worker,
                    sum(current_worker_T),
                    sum(current_worker_C)
                )
            )

            if hands_off:
                # The next section starts where the next task is drawn.
                current_section_start_x = next_pos[0]
                current_section_start_y = next_pos[1]
                current_worker = W[i + 1]
                current_worker_T = []
                current_worker_C = []

        current_pos = next_pos
        all_coords.extend(new_coords)

    all_coords = np.array(all_coords)
    x_min, y_min = all_coords.min(axis=0)
    x_max, y_max = all_coords.max(axis=0)

    # 10% extra headroom above the highest corner.
    max_y = y_max + (y_max - y_min) * 0.1

    # Translucent box behind each worker's run of tasks.
    for start_x, start_y, _end_x, _worker_id, total_t, total_c in worker_sections:
        worker_box = plt.Rectangle(
            (start_x, start_y),
            total_t,
            total_c,
            fill=True,
            facecolor="blue",
            edgecolor="blue",
            linestyle="--",
            linewidth=2.5,   # thicker border
            alpha=0.15
        )
        ax.add_patch(worker_box)

    # Pink handoff strips, labelled h_k where k is the 1-based index of the
    # task that hands off.
    for x, y, h_val, total_c, task_idx in transition_rects:
        if h_val > 0:
            rect = plt.Rectangle(
                (x, y - total_c),
                h_val,
                total_c,
                fill=True,
                facecolor="pink",
                edgecolor="#FF9999",
                linestyle="--",
                linewidth=1.5,
                alpha=0.35
            )
            ax.add_patch(rect)

            cx = x + h_val / 2
            cy = y - total_c / 2
            ax.text(cx, cy, f"$h_{{{task_idx}}}$",
                    ha="center", va="center", fontsize=10)

    padding = max(x_max - x_min, y_max - y_min) * 0.125
    ax.set_xlim(x_min - padding, x_max + padding)
    # NOTE(review): the -0.9 looks like a hand-tuned trim of the top margin
    # (in data units); confirm before changing the figure layout.
    ax.set_ylim(y_min - padding, max_y + padding - .9)
    ax.set_aspect("equal")

    if design_title_string:
        ax.set_title(
            create_title_with_worker_assignments(W),
            fontsize=32,
            pad=20
        )

    ax.grid(False)
    ax.set_xticks([])
    ax.set_yticks([])

    return fig, ax, worker_sections


def generate_worker_assignments(n):
    """
    Yield every way to split ``n`` ordered tasks into contiguous worker runs.

    Each assignment is a list of ``n`` worker ids starting at 1. Every task
    either keeps the previous task's worker or hands off to a new worker
    whose id is one higher. The traversal is depth-first and tries "same
    worker" before "new worker"; for ``n == 3`` it yields::

        [1, 1, 1], [1, 1, 2], [1, 2, 2], [1, 2, 3]

    There are ``2 ** (n - 1)`` assignments for ``n >= 1``, and a single empty
    list for ``n == 0``. The first task has only one valid choice (worker
    1); the previous version branched there anyway and yielded every
    assignment twice.

    Yields:
        A fresh list per assignment, so callers may keep or mutate it.
    """
    assignment = [0] * n

    def _extend(pos):
        # Fill positions pos..n-1 given that assignment[:pos] is fixed.
        if pos == n:
            yield assignment[:]
            return

        if pos == 0:
            assignment[0] = 1
            yield from _extend(1)
            return

        # Same worker as the previous task.
        assignment[pos] = assignment[pos - 1]
        yield from _extend(pos + 1)

        # Hand off to a new worker (ids never decrease, so the previous id
        # is also the largest so far).
        assignment[pos] = assignment[pos - 1] + 1
        yield from _extend(pos + 1)

    yield from _extend(0)


# ============================
# Tent-pole tasks (3 tasks)
# ============================
# Three workers, one task each, no handoff cost. The middle task is much
# taller (C=8) and narrower (T=1) than its neighbours (T=4, C=1).
T = np.array([4, 1, 4])
C = np.array([1, 8, 1])
H = np.array([0, 0, 0])
W = np.array([1, 2, 3])

fig, ax, _ = draw_rect_square_sequence(
    T, C, H, W,
    rectanlge_annotation=True,
    design_title_string=False
)
# Save and close this specific figure rather than pyplot's implicit
# "current" figure.
fig.savefig("../writeup/plots/tent_pole.png", dpi=300)
plt.close(fig)


# ============================
# Job Design with two tasks
# ============================
T = np.array([4, 3])
C = np.array([2, 3])
H = np.array([1.5, 0])

image_files = []
for index, assignment in enumerate(generate_worker_assignments(len(T))):
    # n tasks have 2 ** (n - 1) contiguous worker assignments; stop there
    # in case the generator yields more than that.
    if index == 2 ** (len(T) - 1):
        break

    W = np.array(assignment)
    # Choose the annotations from the assignment itself, not from its
    # position in the generator's output.
    single_worker = bool(np.all(W == W[0]))
    show_individual_annotations = not single_worker

    fig, ax, _ = draw_rect_square_sequence(
        T, C, H, W,
        rectanlge_annotation=show_individual_annotations,
        design_title_string=True
    )

    # Task 1 hands off to another worker: bracket the handoff gap
    # (x from T[0] to T[0] + H[0]) along the bottom and label it t_1^H.
    if W[0] != W[1] and H[0] > 0:
        brace_start_x = T[0]
        brace_end_x = T[0] + H[0]
        brace_y = 0

        draw_brace(
            ax,
            xy1=(brace_start_x, brace_y),
            xy2=(brace_end_x, brace_y),
            width=0.25,
            orientation='down',
            color='red', lw=1
        )
        ax.text(
            (brace_start_x + brace_end_x) / 2,
            brace_y - 0.3,
            r'$t_1^H$',
            ha='center', va='top',
            fontsize=10, color='red'
        )

    # One worker does both tasks: annotate the whole block with summed
    # dimensions instead of per-task labels.
    if single_worker:
        total_t = sum(T)
        total_c = sum(C)

        draw_brace(ax, (0, total_c), (0, 0),
                   width=0.25, orientation='left',
                   color='#1b9e77', lw=1)
        ax.text(-0.3, total_c / 2, r'$c_1+c_2$',
                ha='right', va='center',
                fontsize=10, color='#1b9e77')

        draw_brace(ax, (0, 0), (total_t, 0),
                   width=0.25, orientation='down',
                   color='red', lw=1)
        ax.text(total_t / 2, -0.3, r'$t_1+t_2$',
                ha='center', va='top',
                fontsize=10, color='red')

    filename = f"../writeup/plots/job_design/job_design_{index}.png"
    image_files.append(filename)
    fig.savefig(filename, dpi=110)
    plt.close(fig)

# Combine the 2-task images side by side into one row (1 x N grid).
target_width = 1000
target_height = 800

processed_images = []
for img_path in image_files:
    # Resize inside `with` so each source file is closed once its resized
    # copy exists.
    with Image.open(img_path) as img:
        processed_images.append(
            img.resize((target_width, target_height), Image.LANCZOS)
        )

# Flip the order so the last generated design is on the left.
processed_images = processed_images[::-1]

# One column per image (previously hard-coded to 2 columns).
canvas_width = target_width * len(processed_images)
canvas_height = target_height

grid_image = Image.new("RGB", (canvas_width, canvas_height), "white")

for idx, img in enumerate(processed_images):
    grid_image.paste(img, (idx * target_width, 0))

grid_image.save("../writeup/plots/job_design.png")



# ============================
# Job Design with three tasks
# ============================
T = np.array([1, 2, 2])
C = np.array([3, 1, 2])
H_handoff = np.array([3, 0.5, 0])
H_no_handoff = np.array([0, 0, 0])

# Alternative parameter set, kept for reference:
# T = np.array([6, 1, 6])
# C = np.array([2.5, 5.5, 2.5])
# H_handoff = np.array([16, 0.75, 0])
# H_no_handoff = np.array([0, 0, 0])

for H, my_str in zip([H_handoff, H_no_handoff],
                     ["with_handoff", "no_handoff"]):
    image_files = []
    for index, assignment in enumerate(generate_worker_assignments(len(T))):
        # Only the 2 ** (n - 1) contiguous worker assignments are plotted.
        if index == 2 ** (len(T) - 1):
            break

        W = np.array(assignment)
        fig, ax, _ = draw_rect_square_sequence(
            T, C, H, W,
            rectanlge_annotation=False,
            design_title_string=True
        )
        filename = f"../writeup/plots/job_design/job_design_{index}_{my_str}.png"
        image_files.append(filename)
        fig.savefig(filename, dpi=300)
        plt.close(fig)

    # Load in-memory copies in reverse generation order; `with` closes each
    # file handle right away.
    images = []
    for img_file in image_files[::-1]:
        with Image.open(img_file) as img:
            images.append(img.copy())

    # Near-square layout: grid_size columns and only as many rows as needed,
    # so there is no empty white row at the bottom.
    grid_size = math.ceil(math.sqrt(len(images)))
    n_rows = math.ceil(len(images) / grid_size)
    img_width, img_height = images[-1].size
    canvas_width = img_width * grid_size
    canvas_height = img_height * n_rows
    grid_image = Image.new("RGB", (canvas_width, canvas_height), "white")

    for index, img in enumerate(images):
        x = (index % grid_size) * img_width
        y = (index // grid_size) * img_height
        grid_image.paste(img, (x, y))

    grid_image.save(f"../writeup/plots/combined_grid_{my_str}.png")