In [22]:
# --- Pretty architecture diagrams (ResNet18/34 + MPID + Residual block) ---
# Drop this into a Jupyter notebook cell and run.
#
# Outputs:
#   /home/hep/an1522/dark_tridents_wspace/outputs/masters_poster/network_visualisation/
#     resnet18_screenshot_style.png
#     resnet34_screenshot_style.png
#     mpid_binary_screenshot_style.png
#     residual_block_schematic.png
#
# Notes:
# - No titles (as requested).
# - Consistent font size across blocks.
# - Gaps + arrows only at selected transitions.
# - Arrows are proper arrows (short line + head), pointing DOWN.
# - MPID FC layers fixed to: 12288 -> 1536 -> Sigmoid.
# - ResNet head: "FC Layer, 512 nodes" -> Sigmoid.
# - Global AvgPool label: "Global AvgPool, 16×16".
# - Residual block uses Option B: the skip path re-joins the main path at a circled "+" node before the final ReLU (no "Add (skip)" box).

from pathlib import Path
import matplotlib.pyplot as plt
from matplotlib.patches import FancyBboxPatch, Circle
from matplotlib.lines import Line2D

# -------------------------
# Paths
# -------------------------
# Every diagram below is written into this directory; it is created
# (together with any missing parents) as soon as this cell runs.
OUTDIR = Path(
    "/home/hep/an1522/dark_tridents_wspace/outputs/masters_poster/network_visualisation"
)
OUTDIR.mkdir(exist_ok=True, parents=True)

# -------------------------
# Style
# -------------------------
# Fill colour per block category; blocks refer to these via "color_key".
COLORS = dict(
    input="#2FA6D9",    # blue
    conv="#F0B43C",     # orange
    pool="#43C7C6",     # teal
    fc="#1D7A1D",       # green
    sigmoid="#D44C8E",  # magenta
    skip="#BFBFBF",     # grey
)
COLORS["norm"] = COLORS["pool"]  # normalisation boxes share the pooling teal

# Label colour inside the boxes, and colour of arrows / connector lines.
TEXT_COLOR, ARROW_COLOR = "white", "black"

# Single font size shared by every block label.
FS = 18

# Block geometry, expressed in the diagrams' own data coordinates.
BLOCK_W, BLOCK_H = 10.0, 1.25
RADIUS = 0.35  # corner rounding of a block

# Vertical spacing: blocks touch each other (PAD_SMALL) except where an
# arrow is drawn, which adds GAP_ARROW of extra room.
PAD_SMALL = 0.00
GAP_ARROW = 0.60

# Blank border kept around the whole stack.
MARGIN_X = MARGIN_Y = 0.8

# Arrow appearance: data-coordinate length of the arrow, then the head
# scale and line width handed to matplotlib.
ARROW_LINE_LEN = 0.28
ARROW_MUTATION_SCALE = 16
ARROW_LW = 2.0

# -------------------------
# Helpers
# -------------------------
def _rounded_box(ax, x, y, w, h, face, edge="none", lw=0, r=0.3):
    """
    Add a rounded rectangle to ``ax`` and return the created patch.

    (x, y) is the bottom-left corner and (w, h) the size, in data coords.
    ``face``/``edge`` are the fill and outline colours, ``lw`` the outline
    width and ``r`` the corner rounding size.
    """
    box_style = f"round,pad=0.02,rounding_size={r}"
    appearance = dict(facecolor=face, edgecolor=edge, linewidth=lw)
    box = FancyBboxPatch(
        (x, y), w, h,
        boxstyle=box_style,
        mutation_aspect=1.0,
        **appearance,
    )
    ax.add_patch(box)
    return box

def _center_text(ax, x, y, w, h, s, fs=FS, weight="bold"):
    """
    Write ``s`` centred on the rectangle whose bottom-left corner is (x, y)
    and whose size is (w, h), using font size ``fs`` and font weight
    ``weight``. The text is drawn in TEXT_COLOR with the DejaVu Sans family.
    """
    mid_x = x + w/2.0
    mid_y = y + h/2.0
    ax.text(
        mid_x, mid_y, s,
        color=TEXT_COLOR,
        family="DejaVu Sans",
        fontsize=fs,
        fontweight=weight,
        ha="center", va="center",
    )

def _draw_short_arrow(ax, x, y_top, y_bottom):
    """
    Draw a vertical arrow at horizontal position ``x`` whose tail sits at
    ``y_top`` and whose head sits at ``y_bottom``.

    The arrow is an empty-text annotation with arrowstyle '-|>' (line plus
    head), with no shrinking at either end.
    """
    arrow_style = dict(
        arrowstyle="-|>",
        color=ARROW_COLOR,
        lw=ARROW_LW,
        mutation_scale=ARROW_MUTATION_SCALE,
        shrinkA=0,
        shrinkB=0,
    )
    tail = (x, y_top)
    head = (x, y_bottom)
    ax.annotate("", xy=head, xytext=tail, arrowprops=arrow_style, zorder=50)

def draw_stack_diagram(blocks, arrow_after_indices, outfile, figsize=(6.0, 10.0), dpi=200):
    """
    Draw a top-to-bottom stack of labelled rounded boxes and save it as an image.

    blocks: list of dicts, drawn in order starting from the top:
      {
        "text": str, label centred in the box,
        "color_key": a key of COLORS, or a colour to use directly when the
                     value is not a COLORS key (e.g. a hex string),
        "weight": optional font weight of the label (defaults to "bold")
      }
    arrow_after_indices: iterable of 0-based indices i meaning "insert a gap
                         of GAP_ARROW after block i and draw a short downward
                         arrow in it". Indices that are negative or that have
                         no following block draw no arrow.
    outfile: destination path; missing parent directories are created.
    figsize, dpi: forwarded to plt.figure.

    The figure is always closed before returning, including when drawing or
    saving raises, so failed calls do not accumulate open figures.
    """
    arrow_after_indices = set(arrow_after_indices)
    last = len(blocks) - 1

    # Canvas height: both margins, every block, and the spacing below each
    # block except the last one (nothing follows it).
    total_h = MARGIN_Y * 2
    for i in range(len(blocks)):
        total_h += BLOCK_H
        if i != last:
            total_h += PAD_SMALL
            if i in arrow_after_indices:
                total_h += GAP_ARROW

    total_w = MARGIN_X * 2 + BLOCK_W

    fig = plt.figure(figsize=figsize, dpi=dpi)
    try:
        # One axes filling the whole figure, with data limits equal to the
        # diagram coordinate system computed above.
        ax = fig.add_axes([0, 0, 1, 1])
        ax.set_xlim(0, total_w)
        ax.set_ylim(0, total_h)
        ax.axis("off")

        # Place from top down
        x0 = MARGIN_X
        y_top = total_h - MARGIN_Y  # top edge of first block

        # Remember each block's top/bottom edge so arrows can be centred in
        # the gaps afterwards.
        block_tops = []
        block_bottoms = []

        for i, b in enumerate(blocks):
            y_bottom = y_top - BLOCK_H

            # A color_key that is not in COLORS is used as the colour itself.
            color = COLORS.get(b["color_key"], b["color_key"])
            _rounded_box(ax, x0, y_bottom, BLOCK_W, BLOCK_H, face=color, r=RADIUS)
            _center_text(ax, x0, y_bottom, BLOCK_W, BLOCK_H, b["text"], fs=FS, weight=b.get("weight", "bold"))

            block_tops.append(y_top)
            block_bottoms.append(y_bottom)

            # Top edge of the next block
            y_top = y_bottom - PAD_SMALL
            if i in arrow_after_indices:
                y_top -= GAP_ARROW

        # Arrows: one per requested boundary that actually has a block below.
        cx = x0 + BLOCK_W/2.0
        half = ARROW_LINE_LEN / 2.0
        for i in sorted(arrow_after_indices):
            if i < 0 or i >= len(blocks) - 1:
                continue

            # The gap runs from the bottom of block i down to the top of
            # block i+1; the arrow is centred vertically inside it.
            gap_mid = (block_bottoms[i] + block_tops[i+1]) / 2.0
            _draw_short_arrow(ax, cx, gap_mid + half, gap_mid - half)

        outfile = Path(outfile)
        outfile.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(outfile, bbox_inches="tight", pad_inches=0.05, facecolor="white")
    finally:
        # Release the figure even if drawing or saving failed.
        plt.close(fig)

def draw_residual_block_schematic(outfile, norm_label="Norm", figsize=(7.2, 6.2), dpi=220):
    """
    Draw a residual-block schematic and save it to ``outfile``.

      main:  Input -> Conv 3×3 -> Norm -> ReLU -> Conv 3×3 -> Norm -> (+) -> ReLU
      skip:  Input ---------------------------------------------------^

    The main path is a vertical stack of boxes joined by plain lines (no
    arrowheads). The skip path leaves the right edge of the Input box, runs
    down the right-hand side and enters the circled "+" merge node from the
    right. Every connector that meets the merge node ends on the circle's
    edge.

    outfile: destination path; missing parent directories are created.
    norm_label: text shown in both normalisation boxes.
    figsize, dpi: forwarded to plt.figure.

    The figure is always closed before returning, including when drawing or
    saving raises.
    """
    # Geometry (data coordinates)
    w = 5
    h = 1.10
    r = 0.32
    mx = 1.2
    my = 0.8
    pad = 0.18
    merge_radius = 0.28
    line_lw = 2.6

    # Boxes above the merge node, top to bottom: (label, COLORS key).
    stack = [
        ("Input", "input"),
        ("Conv 3×3", "conv"),
        (norm_label, "norm"),
        ("ReLU", "conv"),
        ("Conv 3×3", "conv"),
        (norm_label, "norm"),
    ]

    # The height budgets seven slots, but eight are laid out below (six
    # boxes, the merge slot and the final ReLU). The eighth slot is covered
    # by extending the lower y-limit by extra_bottom instead. These numbers
    # fix the drawing scale, so they are kept as they are.
    n_boxes = 6 + 1
    total_h = my*2 + n_boxes*h + (n_boxes-1)*pad
    total_w = mx*2 + w + 1  # extra room on the right for the skip line/label
    extra_bottom = 1.5

    fig = plt.figure(figsize=figsize, dpi=dpi)
    try:
        ax = fig.add_axes([0, 0, 1, 1])
        ax.axis("off")
        ax.set_aspect("equal", adjustable="box")
        ax.set_xlim(0, total_w)
        ax.set_ylim(-extra_bottom, total_h)

        x0 = mx
        cx = x0 + w/2  # x position of the main path
        y_top = total_h - my

        # Top/bottom edge of each stacked box, used for wiring afterwards.
        box_tops = []
        box_bottoms = []
        for text, color_key in stack:
            yb = y_top - h
            _rounded_box(ax, x0, yb, w, h, face=COLORS[color_key], r=r)
            _center_text(ax, x0, yb, w, h, text, fs=FS, weight="bold")
            box_tops.append(y_top)
            box_bottoms.append(yb)
            y_top = yb - pad

        # The merge node occupies a slot the size of a box, but is drawn as
        # a circled "+" centred in that slot.
        merge_slot_top = y_top
        merge_slot_bottom = merge_slot_top - h
        merge_x = cx
        merge_y = (merge_slot_top + merge_slot_bottom) / 2

        circ = Circle((merge_x, merge_y), radius=merge_radius, facecolor="white", edgecolor=ARROW_COLOR, linewidth=2.0)
        ax.add_patch(circ)
        ax.text(merge_x, merge_y, "+", ha="center", va="center", fontsize=FS, color=ARROW_COLOR, fontweight="bold")

        # Final ReLU box, one slot below the merge node.
        y_top = merge_slot_bottom - pad
        yb = y_top - h
        _rounded_box(ax, x0, yb, w, h, face=COLORS["conv"], r=r)
        _center_text(ax, x0, yb, w, h, "ReLU", fs=FS, weight="bold")
        final_relu_top = yb + h

        # Main path: bottom of each box to the top of the box below it.
        for lower_edge, next_top in zip(box_bottoms, box_tops[1:]):
            ax.add_line(Line2D([cx, cx], [next_top, lower_edge], lw=line_lw, color=ARROW_COLOR))
        # Second Norm down to the top of the merge circle.
        ax.add_line(Line2D([merge_x, merge_x], [merge_y + merge_radius, box_bottoms[-1]], lw=line_lw, color=ARROW_COLOR))
        # Bottom of the merge circle down to the final ReLU.
        ax.add_line(Line2D([merge_x, merge_x], [final_relu_top, merge_y - merge_radius], lw=line_lw, color=ARROW_COLOR))

        # Skip path, routed on the right of the stack.
        skip_x = x0 + w + 1.0
        input_y = box_bottoms[0] + h/2  # vertical centre of the Input box

        # horizontal out from the Input box
        ax.add_line(Line2D([x0 + w, skip_x], [input_y, input_y], lw=line_lw, color=ARROW_COLOR))
        # vertical down to the height of the merge node
        ax.add_line(Line2D([skip_x, skip_x], [merge_y, input_y], lw=line_lw, color=ARROW_COLOR))
        # horizontal into the right edge of the merge circle
        ax.add_line(Line2D([skip_x, merge_x + merge_radius], [merge_y, merge_y], lw=line_lw, color=ARROW_COLOR))

        # Small rotated label next to the vertical part of the skip path.
        ax.text(skip_x + 0.25, (input_y + merge_y)/2, "skip", rotation=90,
                ha="left", va="center", fontsize=FS-6, color=ARROW_COLOR, family="DejaVu Sans")

        outfile = Path(outfile)
        outfile.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(outfile, bbox_inches="tight", pad_inches=0.25, facecolor="white")
    finally:
        # Release the figure even if drawing or saving failed.
        plt.close(fig)

# -------------------------
# Build diagrams
# -------------------------

def make_resnet_blocks(depth):
    """
    Build the block list and arrow positions for a ResNet stack diagram.

    depth: 18 or 34; it only changes the number of residual blocks shown for
           each of the four stages.

    Returns (blocks, arrow_after): ``blocks`` is a list of
    {"text", "color_key", "weight"} dicts ordered top to bottom, and
    ``arrow_after`` is the set of block indices that are followed by an arrow.

    Raises ValueError for any other depth.
    """
    for known_depth, stage_counts in ((18, (2, 2, 2, 2)), (34, (3, 4, 6, 3))):
        if depth == known_depth:
            break
    else:
        raise ValueError("depth must be 18 or 34")

    def entry(text, color_key):
        # Every label in this diagram is bold.
        return {"text": text, "color_key": color_key, "weight": "bold"}

    front = [
        entry("512×512×1", "input"),
        entry("Stem: 7×7 Conv, stride=2", "conv"),
        entry("MaxPool 3×3, stride=2", "pool"),
    ]
    stages = [
        entry(f"{count} residual blocks ({channels} ch)", "conv")
        for count, channels in zip(stage_counts, (64, 128, 256, 512))
    ]
    head = [
        entry("Global AvgPool, 16×16", "pool"),
        entry("FC Layer, 512 nodes", "fc"),
        entry("Sigmoid", "sigmoid"),
    ]

    # Arrows follow: the input (0), the global average pool (7) and the
    # fully connected layer (8).
    return front + stages + head, {0, 7, 8}

def make_mpid_blocks():
    """
    Build the block list and arrow positions for the MPID stack diagram.

    The stack is the input, five stages of (Conv, Conv, AvgPooling) with
    64/96/128/160/192 channels, then two FC layers and a Sigmoid. Only the
    very first convolution is labelled stride=2; all others are stride=1.

    Returns (blocks, arrow_after): ``blocks`` is a list of
    {"text", "color_key", "weight"} dicts ordered top to bottom, and
    ``arrow_after`` is the set of block indices that are followed by an arrow.
    """
    def entry(text, color_key):
        # Every label in this diagram is bold.
        return {"text": text, "color_key": color_key, "weight": "bold"}

    blocks = [entry("512×512×1", "input")]

    for stage, channels in enumerate((64, 96, 128, 160, 192)):
        first_stride = 2 if stage == 0 else 1
        blocks.append(entry(f"Conv 3×3, stride={first_stride}, {channels} channel", "conv"))
        blocks.append(entry(f"Conv 3×3, stride=1, {channels} channel", "conv"))
        blocks.append(entry("AvgPooling, 2×2", "pool"))

    blocks.extend([
        entry("FC Layer, 12,288 nodes", "fc"),
        entry("FC Layer, 1,536 nodes", "fc"),
        entry("Sigmoid", "sigmoid"),
    ])

    # Arrows follow: the input (0), the last AvgPooling (15), the first FC
    # layer (16) and the second FC layer (17).
    return blocks, {0, 15, 16, 17}

# -------------------------
# Render everything
# -------------------------

# ResNet stacks: one image per depth; the file name carries the depth only.
for depth in (18, 34):
    blocks, arrow_after = make_resnet_blocks(depth)
    out = OUTDIR.joinpath(f"resnet{depth}_screenshot_style.png")
    draw_stack_diagram(blocks, arrow_after, out, figsize=(4.5, 6), dpi=220)

# MPID stack (taller figure, since it has more blocks than the ResNets).
mpid_blocks, mpid_arrows = make_mpid_blocks()
draw_stack_diagram(
    mpid_blocks,
    mpid_arrows,
    OUTDIR.joinpath("mpid_binary_screenshot_style.png"),
    figsize=(5.5, 8),
    dpi=220,
)

# Residual block schematic, with the normalisation boxes labelled "Norm".
draw_residual_block_schematic(
    outfile=OUTDIR / "residual_block_schematic.png",
    norm_label="Norm",
)

print("Saved diagrams to:", OUTDIR)


Saved diagrams to: /home/hep/an1522/dark_tridents_wspace/outputs/masters_poster/network_visualisation
