Skip to content

Commit

Permalink
Plot residual block design via graphviz
Browse files Browse the repository at this point in the history
  • Loading branch information
JiaweiZhuang committed Oct 9, 2019
1 parent 037b8c0 commit f3e647d
Showing 1 changed file with 307 additions and 0 deletions.
307 changes: 307 additions & 0 deletions notebooks/network_design/plot_resnet_block.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,307 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Ref: https://github.com/quark0/darts/blob/master/cnn/visualize.py"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from graphviz import Digraph"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def residual_block(two_blocks=False, filename='residual_block'):\n",
" # general config\n",
" g = Digraph(\n",
" format='pdf',\n",
" edge_attr=dict(fontsize='20', fontname=\"times\"),\n",
" node_attr=dict(style='filled', shape='rect', align='center', fontsize='20', height='0.5', width='0.5', penwidth='2', fontname=\"times\"),\n",
" engine='dot')\n",
" g.body.extend(['rankdir=LR'])\n",
"\n",
" # input nodes\n",
" g.node(\"c_{k-2}\", fillcolor='darkseagreen2')\n",
" g.node(\"c_{k-1}\", fillcolor='darkseagreen2')\n",
" \n",
" # intermediate nodes\n",
" for i in range(4):\n",
" g.node(str(i), fillcolor='lightblue')\n",
"\n",
" # edges\n",
" edge_list = [\n",
" (\"c_{k-1}\", '0', 'conv'),\n",
" (\"0\", '1', 'conv'),\n",
" (\"c_{k-1}\", '1', 'skip_connect') \n",
" ]\n",
" \n",
" if two_blocks:\n",
" edge_list += [\n",
" (\"1\", '2', 'conv'),\n",
" (\"2\", '3', 'conv'),\n",
" (\"1\", '3', 'skip_connect') \n",
" ]\n",
" \n",
" for u, v, op in edge_list:\n",
" g.edge(u, v, label=op, fillcolor=\"gray\")\n",
"\n",
" # output node\n",
" g.node(\"c_{k}\", fillcolor='palegoldenrod')\n",
" \n",
" if two_blocks:\n",
" g.edge('3', \"c_{k}\", fillcolor=\"gray\")\n",
" else:\n",
" g.edge('1', \"c_{k}\", fillcolor=\"gray\") \n",
" \n",
" return g "
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<!-- Generated by graphviz version 2.38.0 (20140413.2041)\n",
" -->\n",
"<!-- Title: %3 Pages: 1 -->\n",
"<svg width=\"481pt\" height=\"206pt\"\n",
" viewBox=\"0.00 0.00 481.09 206.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 202)\">\n",
"<title>%3</title>\n",
"<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-202 477.094,-202 477.094,4 -4,4\"/>\n",
"<!-- c_{k&#45;2} -->\n",
"<g id=\"node1\" class=\"node\"><title>c_{k&#45;2}</title>\n",
"<polygon fill=\"#b4eeb4\" stroke=\"black\" stroke-width=\"2\" points=\"80.6049,-36 0.131407,-36 0.131407,-0 80.6049,-0 80.6049,-36\"/>\n",
"<text text-anchor=\"middle\" x=\"40.3682\" y=\"-12\" font-family=\"times\" font-size=\"20.00\">c_{k&#45;2}</text>\n",
"</g>\n",
"<!-- c_{k&#45;1} -->\n",
"<g id=\"node2\" class=\"node\"><title>c_{k&#45;1}</title>\n",
"<polygon fill=\"#b4eeb4\" stroke=\"black\" stroke-width=\"2\" points=\"80.6049,-90 0.131407,-90 0.131407,-54 80.6049,-54 80.6049,-90\"/>\n",
"<text text-anchor=\"middle\" x=\"40.3682\" y=\"-66\" font-family=\"times\" font-size=\"20.00\">c_{k&#45;1}</text>\n",
"</g>\n",
"<!-- 0 -->\n",
"<g id=\"node3\" class=\"node\"><title>0</title>\n",
"<polygon fill=\"lightblue\" stroke=\"black\" stroke-width=\"2\" points=\"226.377,-124 190.377,-124 190.377,-88 226.377,-88 226.377,-124\"/>\n",
"<text text-anchor=\"middle\" x=\"208.377\" y=\"-100\" font-family=\"times\" font-size=\"20.00\">0</text>\n",
"</g>\n",
"<!-- c_{k&#45;1}&#45;&gt;0 -->\n",
"<g id=\"edge1\" class=\"edge\"><title>c_{k&#45;1}&#45;&gt;0</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M80.7099,-80.0575C111.351,-86.3332 153.09,-94.8817 180.136,-100.421\"/>\n",
"<polygon fill=\"gray\" stroke=\"black\" points=\"179.57,-103.877 190.069,-102.455 180.974,-97.0198 179.57,-103.877\"/>\n",
"<text text-anchor=\"middle\" x=\"118.175\" y=\"-95\" font-family=\"times\" font-size=\"20.00\">conv</text>\n",
"</g>\n",
"<!-- 1 -->\n",
"<g id=\"node4\" class=\"node\"><title>1</title>\n",
"<polygon fill=\"lightblue\" stroke=\"black\" stroke-width=\"2\" points=\"372.018,-90 336.018,-90 336.018,-54 372.018,-54 372.018,-90\"/>\n",
"<text text-anchor=\"middle\" x=\"354.018\" y=\"-66\" font-family=\"times\" font-size=\"20.00\">1</text>\n",
"</g>\n",
"<!-- c_{k&#45;1}&#45;&gt;1 -->\n",
"<g id=\"edge3\" class=\"edge\"><title>c_{k&#45;1}&#45;&gt;1</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M80.6648,-66.4154C102.656,-63.5705 130.601,-60.4185 155.613,-59 202.439,-56.3444 214.355,-55.7065 261.141,-59 283.037,-60.5414 307.653,-64.1249 325.89,-67.1399\"/>\n",
"<polygon fill=\"gray\" stroke=\"black\" points=\"325.377,-70.6027 335.822,-68.8296 326.551,-63.7019 325.377,-70.6027\"/>\n",
"<text text-anchor=\"middle\" x=\"208.377\" y=\"-63\" font-family=\"times\" font-size=\"20.00\">skip_connect</text>\n",
"</g>\n",
"<!-- 0&#45;&gt;1 -->\n",
"<g id=\"edge2\" class=\"edge\"><title>0&#45;&gt;1</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M226.418,-101.966C251.04,-96.1383 296.807,-85.305 326.044,-78.3846\"/>\n",
"<polygon fill=\"gray\" stroke=\"black\" points=\"327.073,-81.7378 335.998,-76.0285 325.461,-74.9261 327.073,-81.7378\"/>\n",
"<text text-anchor=\"middle\" x=\"298.579\" y=\"-93\" font-family=\"times\" font-size=\"20.00\">conv</text>\n",
"</g>\n",
"<!-- c_{k} -->\n",
"<g id=\"node7\" class=\"node\"><title>c_{k}</title>\n",
"<polygon fill=\"palegoldenrod\" stroke=\"black\" stroke-width=\"2\" points=\"473.132,-90 408.979,-90 408.979,-54 473.132,-54 473.132,-90\"/>\n",
"<text text-anchor=\"middle\" x=\"441.056\" y=\"-66\" font-family=\"times\" font-size=\"20.00\">c_{k}</text>\n",
"</g>\n",
"<!-- 1&#45;&gt;c_{k} -->\n",
"<g id=\"edge4\" class=\"edge\"><title>1&#45;&gt;c_{k}</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M372.226,-72C379.925,-72 389.364,-72 398.725,-72\"/>\n",
"<polygon fill=\"gray\" stroke=\"black\" points=\"398.768,-75.5001 408.768,-72 398.768,-68.5001 398.768,-75.5001\"/>\n",
"</g>\n",
"<!-- 2 -->\n",
"<g id=\"node5\" class=\"node\"><title>2</title>\n",
"<polygon fill=\"lightblue\" stroke=\"black\" stroke-width=\"2\" points=\"58.3682,-144 22.3682,-144 22.3682,-108 58.3682,-108 58.3682,-144\"/>\n",
"<text text-anchor=\"middle\" x=\"40.3682\" y=\"-120\" font-family=\"times\" font-size=\"20.00\">2</text>\n",
"</g>\n",
"<!-- 3 -->\n",
"<g id=\"node6\" class=\"node\"><title>3</title>\n",
"<polygon fill=\"lightblue\" stroke=\"black\" stroke-width=\"2\" points=\"58.3682,-198 22.3682,-198 22.3682,-162 58.3682,-162 58.3682,-198\"/>\n",
"<text text-anchor=\"middle\" x=\"40.3682\" y=\"-174\" font-family=\"times\" font-size=\"20.00\">3</text>\n",
"</g>\n",
"</g>\n",
"</svg>\n"
],
"text/plain": [
"<graphviz.dot.Digraph at 0x10ca30278>"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"g = residual_block()\n",
"g.render('residual_blocks', format='png')\n",
"g"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<!-- Generated by graphviz version 2.38.0 (20140413.2041)\n",
" -->\n",
"<!-- Title: %3 Pages: 1 -->\n",
"<svg width=\"772pt\" height=\"138pt\"\n",
" viewBox=\"0.00 0.00 772.38 138.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 134)\">\n",
"<title>%3</title>\n",
"<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-134 768.375,-134 768.375,4 -4,4\"/>\n",
"<!-- c_{k&#45;2} -->\n",
"<g id=\"node1\" class=\"node\"><title>c_{k&#45;2}</title>\n",
"<polygon fill=\"#b4eeb4\" stroke=\"black\" stroke-width=\"2\" points=\"80.6049,-36 0.131407,-36 0.131407,-0 80.6049,-0 80.6049,-36\"/>\n",
"<text text-anchor=\"middle\" x=\"40.3682\" y=\"-12\" font-family=\"times\" font-size=\"20.00\">c_{k&#45;2}</text>\n",
"</g>\n",
"<!-- c_{k&#45;1} -->\n",
"<g id=\"node2\" class=\"node\"><title>c_{k&#45;1}</title>\n",
"<polygon fill=\"#b4eeb4\" stroke=\"black\" stroke-width=\"2\" points=\"80.6049,-90 0.131407,-90 0.131407,-54 80.6049,-54 80.6049,-90\"/>\n",
"<text text-anchor=\"middle\" x=\"40.3682\" y=\"-66\" font-family=\"times\" font-size=\"20.00\">c_{k&#45;1}</text>\n",
"</g>\n",
"<!-- 0 -->\n",
"<g id=\"node3\" class=\"node\"><title>0</title>\n",
"<polygon fill=\"lightblue\" stroke=\"black\" stroke-width=\"2\" points=\"226.377,-124 190.377,-124 190.377,-88 226.377,-88 226.377,-124\"/>\n",
"<text text-anchor=\"middle\" x=\"208.377\" y=\"-100\" font-family=\"times\" font-size=\"20.00\">0</text>\n",
"</g>\n",
"<!-- c_{k&#45;1}&#45;&gt;0 -->\n",
"<g id=\"edge1\" class=\"edge\"><title>c_{k&#45;1}&#45;&gt;0</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M80.7099,-80.0575C111.351,-86.3332 153.09,-94.8817 180.136,-100.421\"/>\n",
"<polygon fill=\"gray\" stroke=\"black\" points=\"179.57,-103.877 190.069,-102.455 180.974,-97.0198 179.57,-103.877\"/>\n",
"<text text-anchor=\"middle\" x=\"118.175\" y=\"-95\" font-family=\"times\" font-size=\"20.00\">conv</text>\n",
"</g>\n",
"<!-- 1 -->\n",
"<g id=\"node4\" class=\"node\"><title>1</title>\n",
"<polygon fill=\"lightblue\" stroke=\"black\" stroke-width=\"2\" points=\"372.018,-90 336.018,-90 336.018,-54 372.018,-54 372.018,-90\"/>\n",
"<text text-anchor=\"middle\" x=\"354.018\" y=\"-66\" font-family=\"times\" font-size=\"20.00\">1</text>\n",
"</g>\n",
"<!-- c_{k&#45;1}&#45;&gt;1 -->\n",
"<g id=\"edge3\" class=\"edge\"><title>c_{k&#45;1}&#45;&gt;1</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M80.6648,-66.4154C102.656,-63.5705 130.601,-60.4185 155.613,-59 202.439,-56.3444 214.355,-55.7065 261.141,-59 283.037,-60.5414 307.653,-64.1249 325.89,-67.1399\"/>\n",
"<polygon fill=\"gray\" stroke=\"black\" points=\"325.377,-70.6027 335.822,-68.8296 326.551,-63.7019 325.377,-70.6027\"/>\n",
"<text text-anchor=\"middle\" x=\"208.377\" y=\"-63\" font-family=\"times\" font-size=\"20.00\">skip_connect</text>\n",
"</g>\n",
"<!-- 0&#45;&gt;1 -->\n",
"<g id=\"edge2\" class=\"edge\"><title>0&#45;&gt;1</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M226.418,-101.966C251.04,-96.1383 296.807,-85.305 326.044,-78.3846\"/>\n",
"<polygon fill=\"gray\" stroke=\"black\" points=\"327.073,-81.7378 335.998,-76.0285 325.461,-74.9261 327.073,-81.7378\"/>\n",
"<text text-anchor=\"middle\" x=\"298.579\" y=\"-93\" font-family=\"times\" font-size=\"20.00\">conv</text>\n",
"</g>\n",
"<!-- 2 -->\n",
"<g id=\"node5\" class=\"node\"><title>2</title>\n",
"<polygon fill=\"lightblue\" stroke=\"black\" stroke-width=\"2\" points=\"517.658,-130 481.658,-130 481.658,-94 517.658,-94 517.658,-130\"/>\n",
"<text text-anchor=\"middle\" x=\"499.658\" y=\"-106\" font-family=\"times\" font-size=\"20.00\">2</text>\n",
"</g>\n",
"<!-- 1&#45;&gt;2 -->\n",
"<g id=\"edge4\" class=\"edge\"><title>1&#45;&gt;2</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M372.059,-76.7456C396.68,-83.6019 442.448,-96.347 471.685,-104.489\"/>\n",
"<polygon fill=\"gray\" stroke=\"black\" points=\"471.066,-107.95 481.639,-107.261 472.944,-101.206 471.066,-107.95\"/>\n",
"<text text-anchor=\"middle\" x=\"409.456\" y=\"-96\" font-family=\"times\" font-size=\"20.00\">conv</text>\n",
"</g>\n",
"<!-- 3 -->\n",
"<g id=\"node6\" class=\"node\"><title>3</title>\n",
"<polygon fill=\"lightblue\" stroke=\"black\" stroke-width=\"2\" points=\"663.299,-96 627.299,-96 627.299,-60 663.299,-60 663.299,-96\"/>\n",
"<text text-anchor=\"middle\" x=\"645.299\" y=\"-72\" font-family=\"times\" font-size=\"20.00\">3</text>\n",
"</g>\n",
"<!-- 1&#45;&gt;3 -->\n",
"<g id=\"edge6\" class=\"edge\"><title>1&#45;&gt;3</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M372.202,-70.1712C406.503,-66.7903 485.886,-60.3162 552.422,-65 574.318,-66.5414 598.934,-70.1249 617.172,-73.1399\"/>\n",
"<polygon fill=\"gray\" stroke=\"black\" points=\"616.658,-76.6027 627.103,-74.8296 617.832,-69.7019 616.658,-76.6027\"/>\n",
"<text text-anchor=\"middle\" x=\"499.658\" y=\"-69\" font-family=\"times\" font-size=\"20.00\">skip_connect</text>\n",
"</g>\n",
"<!-- 2&#45;&gt;3 -->\n",
"<g id=\"edge5\" class=\"edge\"><title>2&#45;&gt;3</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M517.7,-107.966C542.321,-102.138 588.089,-91.305 617.326,-84.3846\"/>\n",
"<polygon fill=\"gray\" stroke=\"black\" points=\"618.355,-87.7378 627.279,-82.0285 616.742,-80.9261 618.355,-87.7378\"/>\n",
"<text text-anchor=\"middle\" x=\"589.86\" y=\"-99\" font-family=\"times\" font-size=\"20.00\">conv</text>\n",
"</g>\n",
"<!-- c_{k} -->\n",
"<g id=\"node7\" class=\"node\"><title>c_{k}</title>\n",
"<polygon fill=\"palegoldenrod\" stroke=\"black\" stroke-width=\"2\" points=\"764.413,-96 700.261,-96 700.261,-60 764.413,-60 764.413,-96\"/>\n",
"<text text-anchor=\"middle\" x=\"732.337\" y=\"-72\" font-family=\"times\" font-size=\"20.00\">c_{k}</text>\n",
"</g>\n",
"<!-- 3&#45;&gt;c_{k} -->\n",
"<g id=\"edge7\" class=\"edge\"><title>3&#45;&gt;c_{k}</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M663.508,-78C671.206,-78 680.645,-78 690.007,-78\"/>\n",
"<polygon fill=\"gray\" stroke=\"black\" points=\"690.049,-81.5001 700.049,-78 690.049,-74.5001 690.049,-81.5001\"/>\n",
"</g>\n",
"</g>\n",
"</svg>\n"
],
"text/plain": [
"<graphviz.dot.Digraph at 0x10ca30780>"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"g = residual_block(two_blocks=True)\n",
"g.render('two_blocks.png', format='png')\n",
"g"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit f3e647d

Please sign in to comment.