# In [2]:
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import math
from pathlib import Path

# Make sure the directory that receives the per-pattern job-design plots exists
# (intermediate directories are created too; an existing directory is fine).
path = Path("..") / "writeup" / "plots" / "job_design"
path.mkdir(exist_ok=True, parents=True)


def create_title_with_worker_assignments(W):
    """
    Build a plot title listing 1-based task numbers, bracketed per worker run.

    Consecutive tasks that share a worker id go inside the same brackets,
    e.g. ``W = [1, 1, 2]`` yields ``"Job Design [1,2][3]"``.

    Raises IndexError if ``W`` is empty (its first element is read up front).
    """
    prev = W[0]
    tokens = []
    for task_no, owner in enumerate(W, 1):
        if owner != prev:
            # Worker changed: close the current bracket and open a new one.
            tokens.append("][")
            prev = owner
        elif tokens:
            # Same worker as the previous task: continue the comma list.
            tokens.append(",")
        tokens.append(str(task_no))
    return "Job Design [" + "".join(tokens) + "]"


def draw_rect_square_unit(ax, x, y, t, c, h, task_idx):
    """
    Draw one task as an outlined blue rectangle labelled with ``task_idx``.

    The hand-off rectangle is NOT drawn here; only where it would go is
    computed, so the caller decides whether to draw it.

    Parameters
    ----------
    ax : axes to draw on
    x, y : bottom-left corner of the task rectangle
    t, c : width and height of the task rectangle
    h : width of the (undrawn) hand-off that would follow this task
    task_idx : label written at the rectangle's centre

    Returns
    -------
    next_pos : (x + t + h, y + c), start of the next task if a hand-off follows
    coords : [(x + t, y + c)], the task's top-right corner (for plot bounds)
    handoff : (x + t, y + c, h), hand-off anchor point and width
    """
    right = x + t
    top = y + c

    outline = plt.Rectangle((x, y), t, c, fill=False, color="blue", linewidth=1)
    ax.add_patch(outline)
    ax.text(x + t / 2, y + c / 2, str(task_idx), ha="center", va="center")

    return (right + h, top), [(right, top)], (right, top, h)


def draw_rect_square_sequence(T, C, H, W, y_trim=1.75):
    """
    Draw tasks as a staircase of rectangles grouped by worker.

    Task i is drawn (via ``draw_rect_square_unit``) as a rectangle of width
    ``T[i]`` and height ``C[i]``. Its bottom-left corner sits at the previous
    task's top-right corner. Each contiguous run of tasks with the same worker
    id in ``W`` is covered by a translucent blue box.

    When the worker changes after task i and ``H[i] > 0``, a pink hand-off
    rectangle of width ``H[i]`` is drawn. It spans the full height of the
    finishing worker's run, is labelled ``h_{i+1}``, and the next task is
    shifted right by ``H[i]``. ``H[i]`` is ignored when the same worker
    continues and for the last task.

    Parameters
    ----------
    T, C, H, W : sequences of equal, non-zero length
        Per-task widths, heights, hand-off widths and worker ids.
    y_trim : float, optional
        Amount subtracted from the padded upper y-limit to trim empty space
        above the drawing. Defaults to 1.75, the value previously hard-coded.
        NOTE(review): on very small drawings this can push the top limit
        below the bottom one, which would invert the y-axis.

    Returns
    -------
    (fig, ax)
        The matplotlib figure and axes. The caller is responsible for closing
        the figure.

    Raises
    ------
    ValueError
        If the input vectors differ in length or are empty.
    """
    n_tasks = len(T)
    if not (n_tasks == len(C) == len(H) == len(W)):
        raise ValueError("All input vectors must have the same length")
    if n_tasks == 0:
        # Checked before plt.subplots so an empty call cannot leave an open
        # figure behind. Previously W[0] raised IndexError after the figure
        # had already been created.
        raise ValueError("Input vectors must not be empty")

    fig, ax = plt.subplots(figsize=(10, 8))
    current_pos = (0, 0)
    all_coords = [(0, 0)]  # points used to compute the plot bounds

    # (start_x, start_y, worker_id, total_t, total_c) per contiguous worker run
    worker_sections = []
    # (x, top_y, handoff_width, total_c, task_number) per worker change
    transition_rects = []

    section_start = (0, 0)
    current_worker = W[0]
    current_worker_T = []
    current_worker_C = []

    for i, (t, c, h, w) in enumerate(zip(T, C, H, W)):
        next_pos, new_coords, handoff_anchor = draw_rect_square_unit(
            ax, current_pos[0], current_pos[1], t, c, h, i + 1
        )
        attached_x, attached_top_y, h_val = handoff_anchor

        is_last = i == n_tasks - 1
        same_worker_next = not is_last and w == W[i + 1]
        worker_changes = not is_last and w != W[i + 1]

        # draw_rect_square_unit reserves room for a hand-off after every task.
        # Remove that gap when the same worker carries on.
        if same_worker_next:
            next_pos = (next_pos[0] - h, next_pos[1])

        current_worker_T.append(t)
        current_worker_C.append(c)

        if worker_changes or is_last:
            total_c = sum(current_worker_C)
            if worker_changes:
                # The hand-off spans the full height of the finishing run.
                transition_rects.append(
                    (attached_x, attached_top_y, h_val, total_c, i + 1)
                )
            worker_sections.append(
                (section_start[0], section_start[1], current_worker,
                 sum(current_worker_T), total_c)
            )
            if not is_last:
                # The next worker's run starts where its first task is drawn.
                section_start = next_pos
                current_worker = W[i + 1]
                current_worker_T = []
                current_worker_C = []

        current_pos = next_pos
        all_coords.extend(new_coords)

    # Plot bounds
    coords = np.array(all_coords)
    x_min, y_min = coords.min(axis=0)
    x_max, y_max = coords.max(axis=0)
    max_y = y_max + (y_max - y_min) * 0.1

    # Draw blue worker bounding boxes
    for start_x, start_y, _worker_id, total_t, total_c in worker_sections:
        worker_box = plt.Rectangle(
            (start_x, start_y),
            total_t,
            total_c,
            fill=True,
            facecolor="blue",
            edgecolor="blue",
            linestyle="--",
            alpha=0.2
        )
        ax.add_patch(worker_box)

    # Draw pink hand-off rectangles ONLY for actual worker transitions
    for x, y, h_val, total_c, task_idx in transition_rects:
        if h_val > 0:
            rect = plt.Rectangle(
                (x, y - total_c),
                h_val,
                total_c,
                fill=True,
                facecolor="pink",
                edgecolor="#FF9999",
                linestyle="--",
                linewidth=1,
                alpha=0.35
            )
            ax.add_patch(rect)

            # label the hand-off
            cx = x + h_val / 2
            cy = y - total_c / 2
            ax.text(cx, cy, f"$h_{{{task_idx}}}$",
                    ha="center", va="center", fontsize=10)

    # Final formatting
    padding = max(x_max - x_min, y_max - y_min) * 0.1
    ax.set_xlim(x_min - padding, x_max + padding)
    ax.set_ylim(y_min - padding, max_y + padding - y_trim)
    ax.set_aspect("equal")
    ax.set_title(create_title_with_worker_assignments(W))
    ax.set_xticks([])
    ax.set_yticks([])

    return fig, ax


def generate_worker_assignments(n):
    """
    Yield every assignment of n tasks to workers in consecutive runs.

    Workers are numbered 1, 2, ... in order of appearance. Each yielded list
    is non-decreasing and steps up by at most 1 between neighbours, e.g.
    ``[1, 1, 2]``. Every boundary between adjacent tasks is either a hand-off
    or not, so exactly ``2 ** (n - 1)`` distinct patterns are yielded for
    ``n >= 1`` (each exactly once), and a single ``[]`` for ``n == 0``. Each
    yielded list is a fresh copy.
    """
    def rec(pos, arr):
        if pos == n:
            yield arr[:]
            return

        if pos == 0:
            # The first task always goes to worker 1. There is only one
            # choice here; trying two branches would yield every pattern twice.
            arr[0] = 1
            yield from rec(1, arr)
            return

        # Same worker as the previous task.
        arr[pos] = arr[pos - 1]
        yield from rec(pos + 1, arr)

        # Hand off to a brand-new worker. arr is non-decreasing, so the
        # previous entry is the largest worker id used so far.
        arr[pos] = arr[pos - 1] + 1
        yield from rec(pos + 1, arr)

    yield from rec(0, [0] * n)


# ============================
# Example usage
# ============================

# NOTE(review): handoff_height is not referenced in the code shown here.
handoff_height = 0.025
T = np.array([6, 1, 6])
C = np.array([2.5, 7.5, 2.5])  # middle should actually be 5.5 but is altered for visual clarity
H_handoff = np.array([15, 0.75, 0])
H_no_handoff = np.array([0, 0, 0])

# Consecutive-run assignments of n tasks: 2**(n-1) distinct patterns.
n_patterns = 2 ** (len(T) - 1)

for H, suffix in zip([H_handoff, H_no_handoff], ["with_handoff", "no_handoff"]):
    image_files = []

    for index, assignment in enumerate(generate_worker_assignments(len(T))):
        # Stop after the distinct patterns, in case the generator repeats them.
        if index == n_patterns:
            break

        W = np.array(assignment)
        fig, ax = draw_rect_square_sequence(T, C, H, W)
        filename = path / f"job_design_{index}_{suffix}.png"
        image_files.append(filename)
        fig.savefig(filename, dpi=100)
        plt.close(fig)

    # Combine into a square grid (last plot first). Each file is opened in a
    # context manager and copied so no file handles stay open.
    grid_size = math.ceil(math.sqrt(len(image_files)))
    images = []
    for img_file in image_files[::-1]:
        with Image.open(img_file) as img:
            images.append(img.copy())
    w, h = images[0].size
    canvas = Image.new("RGB", (w * grid_size, h * grid_size), "white")

    for i, img in enumerate(images):
        x = (i % grid_size) * w
        y = (i // grid_size) * h
        canvas.paste(img, (x, y))

    # Saved one level above `path`, i.e. ../writeup/plots/.
    canvas.save(path.parent / f"combined_grid_{suffix}.png")