Skip to content

Commit

Permalink
Merge 8aa25d5 into 3aa5baa
Browse files Browse the repository at this point in the history
  • Loading branch information
bqpd committed Feb 26, 2021
2 parents 3aa5baa + 8aa25d5 commit ad33978
Show file tree
Hide file tree
Showing 6 changed files with 274 additions and 5 deletions.
3 changes: 3 additions & 0 deletions docs/source/examples/performance_modeling.py
Expand Up @@ -250,3 +250,6 @@ def setup(self):
sankey # pylint: disable=pointless-statement
except (ImportError, ModuleNotFoundError):
print("Making Sankey diagrams requires the ipysankeywidget package")

from gpkit.interactive.references import referencesplot
referencesplot(M, openimmediately=False)
47 changes: 43 additions & 4 deletions docs/source/visint.rst
@@ -1,5 +1,44 @@
Visualization and Interaction
*****************************

Variable Reference Plots
========================

Code in this section uses the `CE solar model <https://github.com/convexengineering/solar/tree/gpkitdocs>`_

.. code:: python
from solar.solar import *
Vehicle = Aircraft(Npod=3, sp=True)
M = Mission(Vehicle, latitude=[20])
M.cost = M[M.aircraft.Wtotal]
sol = M.localsolve()
from gpkit.interactive.references import referencesplot
referencesplot(M, openimmediately=True)
Running the code above will produce two files in your working directory:
``referencesplot.html`` and ``referencesplot.json``, and (unless you specify
``openimmediately=False``) open the former in your web browser,
showing you something like this:

.. figure:: figures/referencesplot.png

`Click to see the interactive version of this plot. <https://web.mit.edu/eburn/www/referencesplot/referencesplot.html>`_

When a model's name is hovered over its connections are highlighted, showing in
red the other models it imports variables from to use in its constraints and in
blue the models that import variables from it.

By default connections are shown with equal width ("Unweighted").
When "Global Sensitivities" is selected, connection width is proportional to
the sensitivity of all variables in that connection to the importing model,
corresponding exactly to how much the model's cost would decrease if those variables were relaxed
in only that importing model. This can give a sense of which connections are vital to
the overall model. When "Normalized Sensitivities" is selected, that
global weight is divided by the weight of all variables in the importing model,
giving a sense of which connections are vital to each subsystem.


.. _sankey:
Sensitivity Diagrams
====================
Expand All @@ -9,9 +48,9 @@ Requirements
- Jupyter Notebook
- `ipysankeywidget <https://github.com/ricklupton/ipysankeywidget>`__
- Note that you'll need to activate these widgets on Jupyter by runnning

- ``jupyter nbextension enable --py --sys-prefix widgetsnbextension``

- ``jupyter nbextension enable --py --sys-prefix ipysankeywidget``

Example
Expand All @@ -25,10 +64,10 @@ Code in this section uses the `CE solar model <https://github.com/convexengineer
Vehicle = Aircraft(Npod=3, sp=True)
M = Mission(Vehicle, latitude=[20])
M.cost = M[M.aircraft.Wtotal]
sol = M.localsolve("mosek_cli")
sol = M.localsolve()
from gpkit.interactive.sankey import Sankey
Once the code above has been run in a Jupyter notebook, the code below will create interactive hierarchies of your model's sensitivities, like so:

.. figure:: figures/Mission.gif
Expand Down
3 changes: 2 additions & 1 deletion fulltests.sh
@@ -1,3 +1,4 @@
python -c "from gpkit.tests import run; run()"
python -c "from gpkit.tests import run; run()"
rm *.pkl
rm solution.*
rm referencesplot.*
86 changes: 86 additions & 0 deletions gpkit/interactive/references.py
@@ -0,0 +1,86 @@
"Code to make variable references plots"

import os
import shutil
import webbrowser
from collections import defaultdict


# pylint:disable=too-many-locals
def referencesplot(model, *, openimmediately=True):
"""Makes a references plot.
1) Creates the JSON file for a d3 references plot
2) Places it and the corresponding HTML file in the working directory
3) (optionally) opens that HTML file up immediately in a web browser
"""
imports = {}
totalv_ss = defaultdict(dict)
for constraint in model.flat():
for varkey in constraint.vks:
vlineage = varkey.lineagestr()
clineage = constraint.lineagestr()
if not vlineage:
vlineage = "%s [%s]" % (varkey, varkey.unitstr())
for lin in (clineage, vlineage):
if lin not in imports:
imports[lin] = set()
if vlineage != clineage:
imports[clineage].add(vlineage)
if constraint.v_ss:
totalv_ss[clineage] += constraint.v_ss

def clean_lineage(lineage, clusterdepth=2):
prelineage = ".".join(lineage.split(".")[:clusterdepth])
last = "0".join(lineage.split(".")[clusterdepth:])
return "model."+prelineage + "." + last

lines = ['jsondata = [']
for lineage, limports in imports.items():
name, short = clean_lineage(lineage), lineage.split(".")[-1]
limports = map(clean_lineage, limports)
lines.append(
' {"name":"%s","fullname":"%s","shortname":"%s","imports":%s},'
% (name, lineage, short, repr(list(limports)).replace("'", '"')))
lines[-1] = lines[-1][:-1]
lines.append("]")

if totalv_ss:
def get_total_senss(clineage, vlineage, normalize=False):
v_ss = totalv_ss[clineage]
num = sum(abs(ss) for vk, ss in v_ss.items()
if vk.lineagestr() == vlineage)
if not normalize:
return num
return num/sum(abs(ss) for ss in v_ss.values())
lines.append("globalsenss = {")
for clineage, limports in imports.items():
if not limports:
continue
limports = {vl: get_total_senss(clineage, vl) for vl in limports}
lines.append(' "%s": %s,' %
(clineage, repr(limports).replace("'", '"')))
lines[-1] = lines[-1][:-1]
lines.append("}")
lines.append("normalizedsenss = {")
for clineage, limports in imports.items():
if not limports:
continue
limports = {vl: get_total_senss(clineage, vl, normalize=True)
for vl in limports}
lines.append(' "%s": %s,' %
(clineage, repr(limports).replace("'", '"')))
lines[-1] = lines[-1][:-1]
lines.append("}")

with open("referencesplot.json", "w") as f:
f.write("\n".join(lines))

htmlfile = "referencesplot.html"
if not os.path.isfile(htmlfile):
shutil.copy(os.path.join(os.path.dirname(__file__), htmlfile), htmlfile)

if openimmediately:
webbrowser.open("file://" + os.path.join(os.getcwd(), htmlfile),
autoraise=True)
139 changes: 139 additions & 0 deletions gpkit/interactive/referencesplot.html
@@ -0,0 +1,139 @@
<!DOCTYPE html>
<meta charset="utf-8">
<title>GPkit Variable Reference Map</title>
<script src="https://d3js.org/d3.v6.min.js"></script>
<script type="text/javascript" src="referencesplot.json"></script>
<body style="font-family: Myriad Pro, sans-serif; color: #333;">
<div id="chart" style="width: 1000px; margin: 0 auto;">
<div id="controls" style="width: 500px; margin: 0 auto;">
<input type="radio" id="unweighted" name="linkwidth" checked class="control"
onchange="chooseStrokeWidth('unweighted')">
<label for="unweighted">Unweighted</label>
<input type="radio" id="global" name="linkwidth" class="control"
onchange="chooseStrokeWidth('global')">
<label for="global">Global Sensitivities</label>
<input type="radio" id="normalized" name="linkwidth" class="control"
onchange="chooseStrokeWidth('normalized')">
<label for="normalized">Normalized Sensitivities</label>
</div>
</div>
</body>
<script>
colorin = "#59ade4"
colorout = "#FA3333"
colornone = "#eee"
width = 954
radius = width / 2

function hierarchy(data, delimiter = ".") {
let root;
const map = new Map;
data.forEach(function find(data) {
const {name} = data;
if (map.has(name)) return map.get(name);
const i = name.lastIndexOf(delimiter);
map.set(name, data);
if (i >= 0) {
find({name: name.substring(0, i), children: []}).children.push(data);
data.name = name.substring(i + 1);
} else {
root = data;
}
return data;
});
return root;
}

function id(node) {
return `${node.parent ? id(node.parent) + "." : ""}${node.data.name}`;
}

function bilink(root) {
const map = new Map(root.leaves().map(d => [id(d), d]));
for (const d of root.leaves()) d.incoming = [], d.outgoing = d.data.imports.map(i => [d, map.get(i)]);
for (const d of root.leaves()) for (const o of d.outgoing) o[1].incoming.push(o);
return root;
}

line = d3.lineRadial()
.curve(d3.curveBundle.beta(0.85))
.radius(function(d) {if (d.data.fullname) return d.y + 15*((d.data.fullname.match(/\./g) || []).length - 1); return d.y})
.angle(d => d.x)

tree = d3.cluster()
.size([2 * Math.PI, radius - 250])

data = hierarchy(jsondata)

root = tree(bilink(d3.hierarchy(data)
.sort((a, b) => d3.ascending(a.height, b.height) || d3.ascending(a.data.name, b.data.name))));

svg = d3.select("#chart")
.append("svg").lower()
.attr("viewBox", [-width / 2, -width / 2, width, width]);

function chooseStrokeWidth(type) {
switch(type) {
case "global":
d3.selectAll("path").attr("stroke-width",
([i, o]) => 0.5*globalsenss[i.data.fullname][o.data.fullname]);
break;
case "normalized":
d3.selectAll("path").attr("stroke-width",
([i, o]) => 36*normalizedsenss[i.data.fullname][o.data.fullname]);
break;
default:
d3.selectAll("path").attr("stroke-width", 12);
}
}

link = svg.append("g")
.attr("stroke", colornone)
.attr("fill", "none")
.selectAll("path")
.data(root.leaves().flatMap(leaf => leaf.outgoing))
.join("path")
.style("mix-blend-mode", "multiply")
.attr("d", ([i, o]) => line(i.path(o)))
.each(function(d) { d.path = this; });

chooseStrokeWidth()

node = svg.append("g")
.attr("fill", "#333")
.selectAll("g")
.data(root.leaves())
.join("g")
.attr("transform", d => `rotate(${d.x * 180 / Math.PI - 90}) translate(${d.y - 1 + 15*((d.data.fullname.match(/\./g) || []).length - 1)},0)`)
.append("text")
.attr("font-size", d => 8*(1 + Math.pow(4, 1-(d.data.fullname.match(/\./g) || []).length)))
.attr("dy", "0.31em")
.attr("x", d => d.x < Math.PI ? 6 : -6)
.attr("text-anchor", d => d.x < Math.PI ? "start" : "end")
.attr("transform", d => d.x >= Math.PI ? "rotate(180)" : null)
.text(d => d.data.shortname)
.each(function(d) { d.text = this; })
.on("mouseover", overed)
.on("mouseout", outed)
.call(text => text.append("title").text(d => `${d.data.fullname}
(red: the ${d.outgoing.length} models this imports variables from)
(blue: the ${d.incoming.length} models that import these variables)`));

function overed(event, d) {
link.style("mix-blend-mode", null);
d3.select(this).attr("font-weight", "bold");
d3.selectAll(d.incoming.map(d => d.path)).attr("stroke", colorin).raise();
d3.selectAll(d.incoming.map(([d]) => d.text)).attr("fill", colorin).attr("font-weight", "bold");
d3.selectAll(d.outgoing.map(d => d.path)).attr("stroke", colorout).raise();
d3.selectAll(d.outgoing.map(([, d]) => d.text)).attr("fill", colorout).attr("font-weight", "bold");
}

function outed(event, d) {
link.style("mix-blend-mode", "multiply");
d3.select(this).attr("font-weight", null);
d3.selectAll(d.incoming.map(d => d.path)).attr("stroke", null);
d3.selectAll(d.incoming.map(([d]) => d.text)).attr("fill", null).attr("font-weight", null);
d3.selectAll(d.outgoing.map(d => d.path)).attr("stroke", null);
d3.selectAll(d.outgoing.map(([, d]) => d.text)).attr("fill", null).attr("font-weight", null);
}
</script>
1 change: 1 addition & 0 deletions runtests.sh
Expand Up @@ -5,3 +5,4 @@ python -c "import gpkit.tests; gpkit.tests.run()" && mv settings gpkit/env
rm *.pkl
rm *.pgz
rm solution.*
rm referencesplot.*

0 comments on commit ad33978

Please sign in to comment.