In [19]:
from pathlib import Path
import matplotlib
matplotlib.use("Agg")

import matplotlib.pyplot as plt
from matplotlib.patches import FancyBboxPatch

# ONNX models are read from, and rendered diagrams written to, the same
# workspace folder. NOTE(review): absolute, user-specific path -- consider
# making it configurable before sharing the notebook.
ONNX_DIR = OUTDIR = Path(
    "/home/hep/an1522/dark_tridents_wspace/outputs/network_visualisation"
)
OUTDIR.mkdir(exist_ok=True, parents=True)


In [20]:
# Box fill colours keyed by block "kind" (palette inspired by your screenshot).
# Kinds not listed here are drawn with the "other" colour.
COLORS = dict(
    input="#2aa7df",  # blue
    conv="#f2b233",   # orange
    pool="#44c7c5",   # teal
    fc="#1f7a1f",     # green
    sig="#d64b8c",    # pink
    other="#dddddd",  # light grey
)

def draw_screenshot_style(blocks, outfile, width=5.2, *, gap=0.10, dpi=300, fontsize=12):
    """
    Draw ``blocks`` as a top-to-bottom stack of rounded, colour-coded boxes
    joined by downward arrows, and save the figure to ``outfile``.

    blocks: list of dicts with keys:
      - text: str
      - kind: one of ['input','conv','pool','fc','sig','other']
              (any other value is drawn with the 'other' colour)
      - height: float (optional, default 1.0) - box height relative to the others
      - text_color: str (optional; defaults to white on the coloured kinds
                    'input','conv','pool','fc','sig' and black otherwise)
    outfile: destination image path; missing parent directories are created.
    width: figure width in inches. The figure height is 0.62 in per height
           unit, with a 6 in minimum.
    gap: vertical space between consecutive boxes, in height units
         (default 0.10, i.e. a tenth of a height-1.0 box).
    dpi: resolution passed to ``fig.savefig``.
    fontsize: label font size in points.
    """
    outfile = Path(outfile)
    outfile.parent.mkdir(parents=True, exist_ok=True)

    # Total stack height in "height units"; drives the figure height.
    h_units = sum(b.get("height", 1.0) for b in blocks)
    fig_h = max(6, 0.62 * h_units)

    fig, ax = plt.subplots(figsize=(width, fig_h))
    try:
        ax.set_axis_off()

        # Geometry in axes coordinates: boxes span x0..x0+w horizontally and
        # the whole stack (boxes + gaps) fills 85% of the height from y=0.95.
        x0, w = 0.04, 0.92
        pad, rounding = 0.012, 0.02
        y = 0.95
        total = h_units + gap * max(len(blocks) - 1, 0)
        unit = 0.85 / total if total > 0 else 0.0

        coloured_kinds = ("input", "conv", "pool", "fc", "sig")
        bold_kinds = ("input", "fc", "sig")

        for i, b in enumerate(blocks):
            bh = b.get("height", 1.0) * unit
            y1 = y - bh

            kind = b.get("kind", "other")
            color = COLORS.get(kind, COLORS["other"])
            txt_color = b.get("text_color", "white" if kind in coloured_kinds else "black")

            # The "round" boxstyle draws its padding *outside* the rectangle it
            # is given. Inset the rectangle by that padding so the visible box
            # occupies exactly [x0, x0+w] x [y1, y]; otherwise neighbouring
            # boxes overlap whenever gap*unit < 2*pad (tall stacks).
            box_pad = min(pad, bh / 4)
            box = FancyBboxPatch(
                (x0 + box_pad, y1 + box_pad), w - 2 * box_pad, bh - 2 * box_pad,
                boxstyle=f"round,pad={box_pad},rounding_size={min(rounding, bh / 2)}",
                linewidth=0.0,
                facecolor=color,
                transform=ax.transAxes,
            )
            ax.add_patch(box)

            ax.text(
                x0 + w / 2, y1 + bh / 2,
                b["text"],
                transform=ax.transAxes,
                ha="center", va="center",
                fontsize=fontsize,
                color=txt_color,
                fontweight="bold" if kind in bold_kinds else "normal",
            )

            # Connector arrow into the next box; none below the final box,
            # where it would point into empty space.
            if i < len(blocks) - 1:
                ax.annotate(
                    "",
                    xy=(x0 + w / 2, y1 - gap * unit * 0.60),
                    xytext=(x0 + w / 2, y1),
                    xycoords=ax.transAxes,
                    # FancyArrowPatch shrinks each end by 2 pt by default, which
                    # swallows the arrow entirely when the gap is only a few
                    # points tall.
                    arrowprops=dict(arrowstyle="-|>", lw=2, color="black",
                                    shrinkA=0, shrinkB=0),
                )

            y = y1 - gap * unit

        fig.savefig(outfile, dpi=dpi, bbox_inches="tight")
    finally:
        # Always release the figure, even if drawing/saving fails mid-way.
        plt.close(fig)

print("renderer ready")


renderer ready


In [14]:
def resnet_blocks(depth):
    """
    Build the block list describing a ResNet-18 or ResNet-34 diagram.

    depth: 18 or 34; any other value raises ValueError("use 18 or 34").
    Returns a list of block dicts (text / kind / optional height) in the
    format consumed by draw_screenshot_style.
    """
    # (depth, residual blocks per stage) for the supported variants.
    stage_table = ((18, (2, 2, 2, 2)), (34, (3, 4, 6, 3)))
    for known_depth, per_stage in stage_table:
        if depth == known_depth:
            break
    else:
        raise ValueError("use 18 or 34")

    layout = [
        {"text": "512×512×1", "kind": "input", "height": 0.9},
        {"text": "Stem: 7×7 Conv, stride=2", "kind": "conv"},
        {"text": "MaxPool 3×3, stride=2", "kind": "pool"},
    ]
    stage_widths = (64, 128, 256, 512)
    for stage_no, (n_blocks, channels) in enumerate(zip(per_stage, stage_widths), start=1):
        layout.append({
            "text": f"Stage {stage_no}: {n_blocks} residual blocks ({channels} ch)",
            "kind": "conv",
        })
    layout.append({"text": "Global Average Pool", "kind": "pool"})
    layout.append({"text": "Fully Connected → logits", "kind": "fc", "height": 1.0})
    return layout

# Write one diagram per supported ResNet depth.
for depth in (18, 34):
    out = OUTDIR.joinpath(f"resnet{depth}_screenshot_style.png")
    draw_screenshot_style(resnet_blocks(depth), out)
    print("wrote", out)


wrote /home/hep/an1522/dark_tridents_wspace/outputs/network_visualisation/resnet18_screenshot_style.png
wrote /home/hep/an1522/dark_tridents_wspace/outputs/network_visualisation/resnet34_screenshot_style.png


In [21]:
import onnx

def get_initializer_shapes(model):
    """Map each initializer name in ``model.graph`` to its dims as a plain list.

    Later initializers with a duplicate name overwrite earlier ones.
    """
    return {tensor.name: list(tensor.dims) for tensor in model.graph.initializer}

def get_attrs(node):
    """Decode every attribute on an ONNX ``node`` into a ``{name: value}`` dict."""
    decoded = {}
    for attribute in node.attribute:
        decoded[attribute.name] = onnx.helper.get_attribute_value(attribute)
    return decoded

def extract_mpid_summary(onnx_path, max_stages=6):
    """
    Summarise an ONNX model as MPID-style (Conv, Conv, Pool) stages plus its
    linear layers.

    Parameters
    ----------
    onnx_path : str or Path
        Model file; it is loaded with ``onnx.load`` and validated with
        ``onnx.checker.check_model`` before inspection.
    max_stages : int, optional
        Upper bound on the number of stages returned (default 6, the
        previously hard-coded limit).

    Returns
    -------
    stages : list of (conv1, conv2, pool) tuples
        Conv dicts carry ``in``/``out`` channel counts, ``k`` = (kH, kW) and
        ``s`` = strides (default [1, 1]); pool dicts carry ``type``
        ("AveragePool" or "MaxPool"), ``k`` = kernel_shape and ``s`` = strides.
        Anything that cannot be read from an initializer or attribute is the
        placeholder string ``"?"``.
    linears : list of (in_features, out_features) tuples
        One entry per Gemm/MatMul node in graph order; ``("?", "?")`` when the
        node's second input is not a 2-D initializer.
    """
    m = onnx.load(str(onnx_path))
    onnx.checker.check_model(m)
    init_shapes = get_initializer_shapes(m)

    def weight_shape(node):
        # The weight is the second input; only initializers have a static shape
        # available here, anything else yields None.
        return init_shapes.get(node.input[1]) if len(node.input) > 1 else None

    # Collected in the order nodes are stored in the graph.
    convs, pools, linears = [], [], []

    for node in m.graph.node:
        op = node.op_type
        if op == "Conv":
            attrs = get_attrs(node)
            wshape = weight_shape(node)
            if wshape and len(wshape) == 4:
                # ONNX Conv weight layout: (out_channels, in_channels / group, kH, kW).
                out_ch, in_ch, kH, kW = wshape
            else:
                out_ch = in_ch = kH = kW = "?"
            convs.append({"in": in_ch, "out": out_ch, "k": (kH, kW),
                          "s": attrs.get("strides", [1, 1])})

        elif op in ("AveragePool", "MaxPool"):
            attrs = get_attrs(node)
            pools.append({
                "type": op,
                "k": attrs.get("kernel_shape", ["?", "?"]),
                "s": attrs.get("strides", ["?", "?"]),
            })

        elif op in ("Gemm", "MatMul"):
            wshape = weight_shape(node)
            if wshape and len(wshape) == 2:
                # MatMul computes A @ B with B shaped (K, N) = (in, out). Gemm is
                # the same unless transB=1, in which case B is stored transposed
                # as (N, K) = (out, in).
                transposed = op == "Gemm" and bool(get_attrs(node).get("transB", 0))
                if transposed:
                    out_f, in_f = wshape
                else:
                    in_f, out_f = wshape
                linears.append((in_f, out_f))
            else:
                linears.append(("?", "?"))

    # Heuristic: the MPID "paper" architecture repeats (Conv, Conv, Pool), so
    # pair the 2i-th and (2i+1)-th convs with the i-th pool, purely by position,
    # stopping when convs or pools run out or max_stages is reached.
    n_stages = max(0, min(len(convs) // 2, len(pools), max_stages))
    stages = [(convs[2 * i], convs[2 * i + 1], pools[i]) for i in range(n_stages)]

    return stages, linears

mpid_path = ONNX_DIR / "mpid_binary.onnx"
stages, linears = extract_mpid_summary(mpid_path)

# Number of (Conv, Conv, Pool) stages to draw; the reference screenshot has 5 pool steps.
N_DRAWN_STAGES = 5


def _kernel_label(k):
    """Render a kernel shape such as (3, 3) as '3×3' (placeholders like '?' pass through)."""
    return "×".join(str(v) for v in k)


def _fmt_count(n):
    """Format an integer with thousands separators; non-integers (e.g. '?') are shown as-is."""
    return f"{n:,}" if isinstance(n, int) else str(n)


# Build blocks like the screenshot
blocks = [{"text": "512×512×1", "kind": "input", "height": 0.9}]

for c1, c2, p in stages[:N_DRAWN_STAGES]:
    for conv in (c1, c2):
        # Use the kernel size actually read from the model rather than assuming 3×3.
        blocks.append({
            "text": f"Conv {_kernel_label(conv['k'])}, stride={conv['s'][0]}, {conv['out']} channel",
            "kind": "conv",
        })
    pool_name = "AvgPooling" if p["type"] == "AveragePool" else "MaxPool"
    blocks.append({"text": f"{pool_name}, {_kernel_label(p['k'])}", "kind": "pool"})

# FC layers (if present), at most two. Each box is labelled with the layer's
# *input* width, so e.g. (12288 -> 1536) then (1536 -> 2) reads 12,288 and 1,536 nodes.
for in_f, out_f in linears[:2]:
    blocks.append({"text": f"Fully Connected Layer, {_fmt_count(in_f)} nodes", "kind": "fc"})

# Sigmoid head
blocks.append({"text": "Sigmoid", "kind": "sig"})

out = OUTDIR / "mpid_binary_screenshot_style.png"
draw_screenshot_style(blocks, out, width=7.5)
print("wrote", out)



wrote /home/hep/an1522/dark_tridents_wspace/outputs/network_visualisation/mpid_binary_screenshot_style.png
