diff --git a/book/_config.yml b/book/_config.yml
index 755efe3d..dd6472a4 100644
--- a/book/_config.yml
+++ b/book/_config.yml
@@ -1,7 +1,7 @@
# Book settings
# Learn more at https://jupyterbook.org/customize/config.html
-title: "v5.7.0"
+title: "v5.7.1"
author: Jeffrey Newman
logo: img/larch-logo.png
@@ -99,6 +99,6 @@ sphinx:
switcher:
json_url: https://larch.newman.me/_static/switcher.json
url_template: https://larch.newman.me/v{version}/
- version_match: 5.7.0
+ version_match: 5.7.1
navbar_end:
- version-switcher
diff --git a/book/_static/switcher.json b/book/_static/switcher.json
index a47cbc9c..2fc4dc84 100644
--- a/book/_static/switcher.json
+++ b/book/_static/switcher.json
@@ -1,7 +1,7 @@
[
{
- "name": "v5.7.0 (latest)",
- "version": "5.7.0"
+ "name": "v5.7.1 (latest)",
+ "version": "5.7.1"
},
{
"version": "5.4.1"
diff --git a/bumpversion.cfg b/bumpversion.cfg
index a98a299a..553c0455 100644
--- a/bumpversion.cfg
+++ b/bumpversion.cfg
@@ -1,5 +1,5 @@
[bumpversion]
-current_version = 5.7.0
+current_version = 5.7.1
commit = True
tag = True
@@ -10,4 +10,3 @@ tag = True
[bumpversion:file:book/_config.yml]
[bumpversion:file:book/_static/switcher.json]
-
diff --git a/conda-recipe/meta.yaml b/conda-recipe/meta.yaml
index 8aee32ae..05797be1 100644
--- a/conda-recipe/meta.yaml
+++ b/conda-recipe/meta.yaml
@@ -1,6 +1,6 @@
package:
name: larch
- version: "5.7.0"
+ version: "5.7.1"
source:
path: ../
@@ -34,7 +34,7 @@ requirements:
- llvm-openmp # [osx]
- {{ pin_compatible('numpy', upper_bound='1.22') }}
- scipy >=1.1
- - pandas >=0.24
+ - pandas >=0.24,<1.5
- pytables >=3.4.4 # https://github.com/conda-forge/pytables-feedstock/issues/31
- blosc >=1.14.3
- matplotlib >=3.0
diff --git a/environments/development.yml b/environments/development.yml
index bbececd9..239fec3b 100644
--- a/environments/development.yml
+++ b/environments/development.yml
@@ -31,7 +31,7 @@ dependencies:
- numexpr
- openmatrix
- openpyxl
- - pandas >=1.2
+ - pandas >=1.2,<1.5
- pillow
- pyarrow
- pydot
diff --git a/larch/__init__.py b/larch/__init__.py
index fa536e5d..0345d1ac 100644
--- a/larch/__init__.py
+++ b/larch/__init__.py
@@ -1,5 +1,5 @@
-__version__ = '5.7.0'
+__version__ = '5.7.1'
from .util.interface_info import Info, ipython_status
import sys
diff --git a/larch/model/tree.py b/larch/model/tree.py
index 97eddb48..0278557f 100644
--- a/larch/model/tree.py
+++ b/larch/model/tree.py
@@ -1,1015 +1,1194 @@
-import networkx as nx
-from collections import OrderedDict
-import numpy
import heapq
-from ..util.touch_notifier import TouchNotify
+from collections import OrderedDict
+
+import networkx as nx
+import numpy as np
+
from ..util.lazy import lazy
+
class NestingTree(nx.DiGraph):
- node_dict_factory = OrderedDict
- adjlist_dict_factory = OrderedDict
-
- def __get__(self, instance, owner):
- # self : SubkeyStore
- # instance : instance of parent class that has `self` as a member, or None
- # owner : class of `instance`
- if instance is None:
- pass # print("GRR: no instance")
- return self
- newself = getattr(instance, self.private_name, None)
- if newself is None:
- pass # print(f"GRR No Current: {instance=} {owner=}")
- try:
- instance.initialize_graph()
- except ValueError:
- pass
- newself = getattr(instance, self.private_name, None)
- if newself is not None:
- newself._instance = instance
- pass # print(f"GRR: get {instance=} {newself=}")
- return newself
-
- def __set__(self, instance, value):
- # self : NestingTree object
- # instance : instance of parent class that has `self` as a member
- # value : the new value that is trying to be assigned
- assert isinstance(value, NestingTree)
- t = value.copy()
- t._instance = instance
- setattr(instance, self.private_name, t)
- try:
- t._instance.mangle()
- except AttributeError as err:
- pass # print(f"GRR: {err}")
- else:
- pass # print(f"GRR Mangle: {instance}")
-
- def __delete__(self, instance):
- setattr(instance, self.private_name, None)
- try:
- instance.mangle()
- except AttributeError as err:
- pass # print(f"GRR: {err}")
- else:
- pass # print(f"GRR Mangle: {instance}")
-
- def __set_name__(self, owner, name):
- self.name = name
- self.private_name = "_private_"+name
-
- def touch(self):
- try:
- self._instance.mangle()
- except AttributeError:
- pass # print("GRR: mangle failure")
- else:
- pass # print("GRR: mangle ok")
-
-
- def __init__(self, *arg, root_id=0, suggested_elemental_order=(), **kwarg):
- if len(arg) and isinstance(arg[0], NestingTree):
- super().__init__(*arg, **kwarg)
- self._root_id = arg[0]._root_id
- if suggested_elemental_order != ():
- self._suggested_elemental_order = suggested_elemental_order
- else:
- self._suggested_elemental_order = arg[0]._suggested_elemental_order
- else:
- super().__init__(*arg, **kwarg)
- self._root_id = root_id
- self._suggested_elemental_order = suggested_elemental_order
- if self._root_id not in self.nodes:
- self.add_node(root_id, name='_root_', root=True)
- self._clear_caches()
-
- def __eq__(self, other):
- return (
- self._adj == other._adj and
- self._node == other._node and
- self._root_id == other._root_id and
- self._suggested_elemental_order == other._suggested_elemental_order
- )
-
- def suggest_elemental_order(self, order):
- self._suggested_elemental_order = tuple(j for j in order if j in self.nodes)
-
- @property
- def root_id(self):
- """int : The code for the root node."""
- return self._root_id
-
- @root_id.setter
- def root_id(self, x):
- top_nests = list(self.successors(self._root_id))
- top_attrs = [self.edges[self._root_id,t] for t in top_nests]
- if self._root_id in self.nodes:
- self.remove_node(self._root_id)
- self._root_id = x
- if self._root_id not in self.nodes:
- self.add_node(self._root_id, name='_root_', root=True)
- for t,a in zip(top_nests, top_attrs):
- self.add_edge(self._root_id, t, **a, _clear_caches=False)
- self._clear_caches()
-
- def _clear_caches(self):
- NestingTree.topological_sorted.invalidate(self, 'topological_sorted')
- NestingTree.topological_sorted_no_elementals.invalidate(self, 'topological_sorted_no_elementals')
- NestingTree.standard_sort.invalidate(self, 'standard_sort')
- NestingTree.standard_sort.invalidate(self, 'standard_slot_map')
- NestingTree.elementals.invalidate(self, 'elementals')
- NestingTree.standard_competitive_edge_list.invalidate(self, 'standard_competitive_edge_list')
- NestingTree.standard_competitive_edge_list_2.invalidate(self, 'standard_competitive_edge_list_2')
- self._predecessor_slots = {}
- self._successor_slots = {}
- self.touch()
-
- def add_edge(self, u, v, implied=False, _clear_caches=True, **kwarg):
- """
- Add an edge between u and v.
-
- The nodes u and v will be automatically added if they are
- not already in the graph.
-
- Edge attributes can be specified with keywords.
-
- Parameters
- ----------
- u, v : int
- Nodes should be integer codes. The upstream node `u` is
- a nest or the root node. Downsteam node `v` can be
- a nest or elemental alternative.
- implied : bool, default False
- Implied edges are for connection of otherwise unconnected
- nests to the root node.
- _clear_caches : bool, default True
- kwarg : keyword arguments, optional
- Edge data (or labels or objects) can be assigned using
- keyword arguments.
- """
- if not implied:
- drops = []
- for u_,v_,imp_ in self.in_edges(nbunch=[v], data='implied'):
- if imp_:
- drops.append([u_,v_])
- for d in drops:
- super().remove_edge(*d)
- if _clear_caches: self._clear_caches()
- return super().add_edge(int(u), int(v), implied=implied, **kwarg)
-
- def _remove_edge_no_implied(self, u, v, *arg, **kwarg):
- result = super().remove_edge(u, v)
- self._clear_caches()
- return result
-
- def remove_edge(self, u, v, *arg, **kwarg):
- """
- Remove the edge between u and v.
-
- Parameters
- ----------
- u, v : int
- Remove the edge between nodes u and v.
-
- Raises
- ------
- NetworkXError
- If there is not an edge between u and v.
- """
- result = super().remove_edge(u, v)
- if self.in_degree(v)==0 and v!=self._root_id:
- self.add_edge(self._root_id, v, implied=True, _clear_caches=False)
- self._clear_caches()
- return result
-
- def add_node(self, code, *, children=(), parent=None, parents=None, phi_parameters=None, **kwarg):
- """
- Add a single node `code` and update node attributes.
-
- Parameters
- ----------
- code : int
- Although the generic networkx.DiGraph allows a node
- to be any hashable Python object except None, Larch
- assumes that node codes are integers.
- children : Collection
- A collection of other node codes that are the children
- of this new node. Links will be created from this node
- to each child.
- parent : int, optional
- The parent of this new node. If not given, the root
- node is assumed to be the parent of this node, and an
- implied link is created. This implied link is removed
- if the node is later made the child of some other node.
- If the parent is set explicitly, the link is *not*
- removed later.
- parents : Collection, optional
- Set multiple parent up-stream nodes.
- phi_parameters : Mapping
- Set phi parameters on graph links connecting to this
- node, used in network GEV models. The keys of this mapping
- indicate the node at the other end of the link, and the
- values are parameter names.
- kwarg : other keyword arguments, optional
- Set or change node attributes using key=value.
- """
- if parents is not None and parent is not None:
- raise TypeError("cannot give both parent and parents arguments")
- super().add_node(code, **kwarg)
- for child in children:
- self.add_edge(code, child, _clear_caches=False)
- if parent is not None:
- self.add_edge(parent, code, _clear_caches=False)
- elif parents is not None:
- for p in parents:
- self.add_edge(p, code, _clear_caches=False)
- else:
- if self.in_degree(code)==0 and code!=self._root_id:
- self.add_edge(self._root_id, code, implied=True, _clear_caches=False)
- if phi_parameters is not None:
- for k, parametername in phi_parameters.items():
- if (code, k) in self.edges:
- self.edges[code, k]['parameter'] = str(parametername)
- elif (k,code) in self.edges:
- self.edges[k, code]['parameter'] = str(parametername)
- else:
- raise ValueError(f"connected node {k} from phi_parameters not found")
- self._clear_caches()
-
- def new_node(self, *, code=None, **kwarg):
- """
- Add a new nesting node to this NestingTree.
-
- A new unique code is automatically created and returned by
- this method for creating new nests.
-
- All arguments must be given as keyword parameters.
-
- Parameters
- ----------
- parameter : str
- The name of the parameter to associate with this nest.
- children : Collection[int], optional
- The code numbers for the children of this nest. These can be
- elemental alternatives or other nests. If not given, no children
- will be defined initially, but they can be added later.
- parent : int, optional
- The code number for the parent of this nest. If not given,
- the parent is implied as the root node, unless and until set
- to some other node.
- name : str, optional
- A human-readable name to associate with this nest.
- code : int, optional
- Use this code for the new nest. If this code already exists,
- a ValueError is raised.
-
- Returns
- -------
- int
- The new code for this nest.
-
- Raises
- ------
- ValueError
- If a new code is given but it already exists in this tree.
- """
- if code is None:
- proposed_code = len(self)
- while proposed_code in self:
- proposed_code += 1
- else:
- if code in self:
- raise ValueError(f'code {code} already exists in this tree')
- proposed_code = code
- self.add_node(proposed_code, **kwarg)
- return proposed_code
-
- def add_nodes(self, codes, *arg, parent=None, **kwarg):
- for code in codes:
- self.add_node(code, *arg, parent=parent, **kwarg)
-
- def remove_node(self, n):
- """
- Remove node n.
-
- Removes the node n, reconnecting all outedges to the head node
- of all inedges. Attempting to remove a non-existent node will
- raise an exception.
-
- Parameters
- ----------
- n : int
- A node in the graph
- """
- replace_edges = {
- k: self.edges[k].copy()
- for k in self.edges(n)
- }
- replace_heads = [k for k, _ in self.in_edges(n)]
- super(NestingTree, self).remove_node(n)
- for k, attrs in replace_edges.items():
- for h in replace_heads:
- super().add_edge(h, k[1], **attrs)
- self._clear_caches()
-
- @lazy
- def topological_sorted(self):
- return list(reverse_lexicographical_topological_sort(self))
-
- @lazy
- def topological_sorted_no_elementals(self):
- try:
- result = self.topological_sorted.copy()
- except nx.NetworkXUnfeasible:
- from networkx.algorithms.cycles import find_cycle
- try:
- cycle = find_cycle(self)
- except:
- pass
- else:
- print("Found graph cycle:")
- print(list(cycle))
- raise
-
- # collect zero out-degree (elemental) codes in a set
- # to remove them in a batch, which is much faster than
- # removing them from the list one at a time.
- to_remove = set()
- for code, out_degree in self.out_degree:
- if not out_degree:
- to_remove.add(code)
- return [i for i in result if i not in to_remove]
-
- @lazy
- def standard_sort(self):
- return self.elementals + tuple(self.topological_sorted_no_elementals)
-
- def node_name(self, code):
- return self.nodes[code].get('name', str(code))
-
- @property
- def standard_sort_names(self):
- return [self.node_name(s) for s in self.standard_sort]
-
- def node_names(self):
- return {s:(self.node_name(s) or s) for s in self.standard_sort}
-
- def elemental_names(self):
- return {s:(self.node_name(s) or s) for s in self.elementals}
-
- @lazy
- def standard_slot_map(self):
- return {i:n for n,i in enumerate(self.standard_sort)}
-
- def predecessor_slots(self, code):
- if code in self._predecessor_slots:
- return self._predecessor_slots[code]
- s = numpy.empty(self.in_degree(code), dtype=numpy.int32)
- for n,i in enumerate( self.predecessors(code) ):
- s[n] = self.standard_slot_map[i]
- self._predecessor_slots[code] = s
- return s
-
- def successor_slots(self, code):
- if code in self._successor_slots:
- return self._successor_slots[code]
- s = numpy.empty(self.out_degree(code), dtype=numpy.int32)
- for n,i in enumerate( self.successors(code) ):
- s[n] = self.standard_slot_map[i]
- self._successor_slots[code] = s
- return s
-
- def __elementals_iter(self):
- for code, out_degree in self.out_degree:
- if not out_degree:
- yield code
-
- @lazy
- def elementals(self):
- result = []
- found = set()
- for e in self._suggested_elemental_order:
- if self.out_degree(e)==0:
- result.append(e)
- found.add(e)
- for e in sorted(self.__elementals_iter()):
- if e not in found:
- result.append(e)
- found.add(e)
- return tuple(result)
-
- def n_elementals(self):
- return len(self.elementals)
-
- def n_intermediate_nests(self):
- return len(self.nodes) - self.n_elementals() - 1
-
- def elemental_descendants_iter(self, code):
- if not self.out_degree(code):
- yield code
- return
- all_d = nx.descendants(self, code)
- for dcode, dout_degree in self.out_degree(all_d):
- if not dout_degree:
- yield dcode
-
- def elemental_descendants(self, code):
- return [i for i in self.elemental_descendants_iter(code)]
-
- @property
- def n_edges(self):
- return self.number_of_edges()
-
- def edge_slot_arrays(self, alpha_locator=None):
- s = self.n_edges
- up = numpy.zeros(s, dtype=numpy.int32)
- dn = numpy.zeros(s, dtype=numpy.int32)
- first_visit = numpy.zeros(s, dtype=numpy.int32)
- alloc_slot = numpy.full_like(first_visit, -1)
- n = s
- first_visit_found = set()
- for upcode in reversed(self.standard_sort):
- upslot = self.standard_slot_map[upcode]
- for dnslot in reversed(self.successor_slots(upcode)):
- n -= 1
- up[n] = upslot
- dn[n] = dnslot
- for n in range(s):
- if dn[n] not in first_visit_found:
- first_visit[n] = 1
- first_visit_found.add(dn[n])
- if alpha_locator is not None:
- for n in range(s):
- alloc_slot[n] = alpha_locator.get( (self.standard_sort[up[n]], self.standard_sort[dn[n]]), -1 )
- return up, dn, first_visit, alloc_slot
-
- def nodes_with_successors_iter(self):
- for code, out_degree in self.out_degree:
- if out_degree:
- yield code
-
- def nodes_with_multiple_predecessors_iter(self):
- for code, in_degree in self.in_degree:
- if in_degree>1:
- yield code
-
- @lazy
- def standard_competitive_edge_list(self):
- alphas = []
- for n in self.nodes_with_multiple_predecessors_iter():
- predecessors = sorted(self.predecessors(n))
- for k in predecessors:
- alphas.append( (k,n) )
- return alphas
-
- @lazy
- def standard_competitive_edge_list_2(self):
- alphas = []
- for n in self.nodes_with_multiple_predecessors_iter():
- predecessors = sorted(self.predecessors(n))
- alphas.append( (predecessors, n) )
- return alphas
-
-
- def __getstate__(self):
- attr = {} # self.__dict__.copy()
- no_pickle = (
- 'topological_sorted',
- 'topological_sorted_no_elementals',
- 'standard_sort',
- 'standard_slot_map',
- #'_standard_elemental_sort',
- 'elementals',
- '_predecessor_slots',
- '_successor_slots',
- '_touch',
- 'node_dict_factory',
- '_instance',
- )
- for k,v in self.__dict__.items():
- if k not in no_pickle:
- attr[k] = v
- return attr
-
- def __setstate__(self, state):
- self.__dict__ = state.copy()
- self._predecessor_slots = {}
- self._successor_slots = {}
-
- def __xml__(self, use_viz=True, use_dot=True, output='svg', figsize=None, **format):
- viz = None
- dot = None
- if use_viz:
- try:
- import pygraphviz as viz
- except ImportError:
- if use_dot:
- try:
- import pydot as dot
- except ImportError:
- pass
- elif use_dot:
- try:
- import pydot as dot
- except ImportError:
- pass
-
- if viz is None and dot is None:
- import warnings
- if use_viz and use_dot:
- msg = "neither pydot nor pygraphviz modules are installed, unable to draw nesting tree"
- elif use_viz:
- msg = "pygraphviz module not installed, unable to draw nesting tree"
- elif use_dot:
- msg = "pydot module not installed, unable to draw nesting tree"
- else:
- msg = "no drawing module used, unable to draw nesting tree"
- warnings.warn(msg)
- raise NotImplementedError(msg)
-
- if viz is not None:
- existing_format_keys = list(format.keys())
- for key in existing_format_keys:
- if key.upper()!=key: format[key.upper()] = format[key]
- if 'SUPPRESSGRAPHSIZE' not in format:
- if 'GRAPHWIDTH' not in format: format['GRAPHWIDTH'] = 6.5
- if 'GRAPHHEIGHT' not in format: format['GRAPHHEIGHT'] = 4
- if 'UNAVAILABLE' not in format: format['UNAVAILABLE'] = True
- # x = XML_Builder("div", {'class':"nesting_graph larch_art"})
- # x.h2("Nesting Structure", anchor=1, attrib={'class':'larch_art_xhtml'})
- from io import BytesIO
- if 'SUPPRESSGRAPHSIZE' not in format:
- G=viz.AGraph(name='Tree',directed=True,size="{GRAPHWIDTH},{GRAPHHEIGHT}".format(**format))
- else:
- G=viz.AGraph(name='Tree',directed=True)
- for n in self.nodes:
- nname = self.nodes[n].get('name', n)
- if nname == n:
- G.add_node(n, label='<{1}>'.format(n,nname), style='rounded,solid', shape='box')
- else:
- G.add_node(n, label='<{1} ({0})>'.format(n,nname), style='rounded,solid', shape='box')
- eG = G.add_subgraph(name='cluster_elemental', nbunch=self.elementals, color='#cccccc', bgcolor='#eeeeee',
- label='Elemental Alternatives', labelloc='b', style='rounded,solid')
- unavailable_nodes = set()
- # if format['UNAVAILABLE']:
- # if self.is_provisioned():
- # try:
- # for n, ncode in enumerate(self.alternative_codes()):
- # if numpy.sum(self.Data('Avail'),axis=0)[n,0]==0: unavailable_nodes.add(ncode)
- # except: raise
- # try:
- # legible_avail = not isinstance(self.df.queries.avail, str)
- # except:
- # legible_avail = False
- # if legible_avail:
- # for ncode,navail in self.df.queries.avail.items():
- # try:
- # if navail=='0': unavailable_nodes.add(ncode)
- # except: raise
- # eG.add_subgraph(name='cluster_elemental_unavailable', nbunch=unavailable_nodes, color='#bbbbbb', bgcolor='#dddddd',
- # label='Unavailable Alternatives', labelloc='b', style='rounded,solid')
- G.add_node(self.root_id, label="Root")
- up_nodes = set()
- down_nodes = set()
- for i,j in self.edges:
- G.add_edge(i,j)
- down_nodes.add(j)
- up_nodes.add(i)
- pyg_imgdata = BytesIO()
- try:
- G.draw(pyg_imgdata, format=output, prog='dot') # write postscript in k5.ps with neato layout
- except ValueError as err:
- if 'in path' in str(err):
- import warnings
- warnings.warn(str(err)+"; unable to draw nesting tree in report")
- raise NotImplementedError()
- if output=='svg':
- import xml.etree.ElementTree as ET
- ET.register_namespace("","http://www.w3.org/2000/svg")
- ET.register_namespace("xlink","http://www.w3.org/1999/xlink")
- return ET.fromstring(pyg_imgdata.getvalue().decode())
- else:
- raise NotImplementedError(f"output {output} with use_viz")
- else:
-
- pydot = dot
-
- # set Graphviz graph type
- if self.is_directed():
- graph_type = 'digraph'
- else:
- graph_type = 'graph'
- strict = nx.number_of_selfloops(self) == 0 and not self.is_multigraph()
-
- name = self.name
- graph_defaults = self.graph.get('graph', {})
- if name == '':
- P = pydot.Dot('', graph_type=graph_type, strict=strict,
- **graph_defaults)
- else:
- P = pydot.Dot('"%s"' % name, graph_type=graph_type, strict=strict,
- **graph_defaults)
- try:
- P.set_node_defaults(**self.graph['node'])
- except KeyError:
- pass
- try:
- P.set_edge_defaults(**self.graph['edge'])
- except KeyError:
- pass
-
- cluster_elemental = pydot.Cluster(
- 'elemental',
- style='rounded',
- bgcolor='lightgrey',
- color='white',
- rank='same',
- rankdir="LR",
- )
-
- for n, nodedata in self.nodes(data=True):
- str_nodedata = dict((k if k!='name' else 'name_', '"'+str(v)+'"') for k, v in nodedata.items())
-
- if 'parameter' in nodedata:
- param_label = '
{0}'.format(nodedata['parameter'])
- else:
- param_label = ''
-
- if 'name' in nodedata and n != self.root_id:
- name = nodedata['name']
- str_nodedata['label'] = '<' \
- '({1}) ' \
- '{0}' \
- '{2}>'.format(name,n,param_label)
-
- # Default styling for nodes can have been overridden
- if n in self.elementals:
- str_nodedata['style'] = str_nodedata.get('style', 'filled')
- str_nodedata['fillcolor'] = str_nodedata.get('fillcolor', 'white')
- elif n == self.root_id:
- str_nodedata['shape'] = str_nodedata.get('shape', 'invhouse')
- else:
- str_nodedata['style'] = str_nodedata.get('style', 'rounded')
- str_nodedata['shape'] = str_nodedata.get('shape', 'rectangle')
-
- p = pydot.Node(str(n), **str_nodedata)
- P.add_node(p)
- if n in self.elementals:
- cluster_elemental.add_node(p)
-
- P.add_subgraph(cluster_elemental)
-
- if self.is_multigraph():
- for u, v, key, edgedata in self.edges(data=True, keys=True):
- str_edgedata = dict((k, str(v_)) for k, v_ in edgedata.items()
- if k != 'key')
- if v in self.elementals:
- str_edgedata['constraint'] = 'false'
- edge = pydot.Edge(str(u), str(v),
- key=str(key), **str_edgedata)
- P.add_edge(edge)
-
- else:
- for u, v, edgedata in self.edges(data=True):
- str_edgedata = dict((k, '"'+str(v)+'"') for k, v in edgedata.items())
- edge = pydot.Edge(str(u), str(v), **str_edgedata)
- P.add_edge(edge)
-
- ###
- from xmle import Elem
- prog = None
- if output == 'svg':
- import xml.etree.ElementTree as ET
- ET.register_namespace("","http://www.w3.org/2000/svg")
- ET.register_namespace("xlink","http://www.w3.org/1999/xlink")
- elif output == 'png':
- prog = [P.prog, '-Gdpi=300']
- if figsize is not None:
- prog.append(f"-Gsize={figsize[0]},{figsize[1]}\!")
- e = Elem.from_any(P.create(prog=prog, format=output, **format))
- e.attrib['dpi'] = (300,300)
- return e
- return Elem.from_any(P.create(prog=prog, format=output, **format))
-
- def _repr_html_(self):
- from xmle import Elem
- x = Elem('div') << (self.__xml__())
- return x.tostring()
-
- def to_png(self, figsize=None, filename=None):
- """
- Output the graph visualization as a png.
-
- Parameters
- ----------
- figsize : 2-tuple, optional
- The (width, height) in inches.
-
- Returns
- -------
- xmle.Elem
- """
- result = self.__xml__(output='png', use_viz=False, figsize=figsize)
- if filename is not None:
- import base64
- if result.attrib['src'][:22] != "data:image/png;base64,":
- raise ValueError("problem decoding png:{}".format(result.attrib['src'][:22]))
- with open(filename, "wb") as fh:
- fh.write(base64.decodebytes(result.attrib['src'][22:].encode()))
- return result
-
- def partial_figure(self, including_nodes=None, source=None, *, n=None, n_at_level=3, n_expand=1):
- """
- Generate a partial figure of the graph.
-
- Parameters
- ----------
- including_nodes : iterable or None
- An iterable containing node codes or names (or a mix).
- source : nodecode, optional
- All paths from this node to everything in `including_nodes` will be represented.
- Defaults to `root_id`.
- n : int
- If including_nodes is None, select this number of nodes randomly
-
- Returns
- -------
- Elem
- """
- if source is None:
- source = self.root_id
- from networkx.algorithms.simple_paths import all_simple_paths
- shows = set()
-
- if including_nodes is None and n is not None:
- including_nodes = sorted(numpy.random.choice(self.nodes, n, replace=False))
-
- if including_nodes is None and n_at_level is not None:
- from collections import deque
- import itertools
- q = deque([self.root_id])
- including_nodes = []
- while q:
- i = q.popleft()
- take = tuple(itertools.islice(self.successors(i), n_at_level))
- q.extend(take[:n_expand])
- including_nodes.extend(take)
-
- # Add every node in every path from the root to each `including_nodes`
- for each_node in including_nodes:
- if each_node in self.nodes:
- for i in all_simple_paths(self, source, each_node):
- for j in i:
- shows.add(j)
- else:
- for each_node_ in self.get_nodes_by_name(each_node):
- for i in all_simple_paths(self, source, each_node_):
- for j in i:
- shows.add(j)
- s = self.subgraph(shows)
- return graph_to_figure(s)
-
- def get_nodes_by_name(self, name):
- result = [k for k,v in self.nodes(data=True) if v.get('name')==name]
- return result
-
- def subgraph_from(self, node):
- from collections import deque
- Q = deque([node])
- found = set()
- while len(Q):
- i = Q.popleft()
- if i not in found:
- found.add(i)
- Q.extend(self.successors(i))
- return NestingTree(self.subgraph(found), root_id=node)
-
- def stats_summarize(self):
- print("Graph Stats")
- print(f" Overall: {len(self)} nodes")
- tier = [self.root_id]
- next_tier = list(self.successors(self.root_id))
- n = 0
- while len(next_tier):
- tier = next_tier
- n += 1
- print(f" Tier {n}: {len(tier)} nodes")
- next_tier = list()
- for i in tier:
- next_tier.extend(self.successors(i))
-
- def node_slot_arrays(self, model):
- muslots= numpy.full([len(self), ], -1, dtype=numpy.int32)
- for child, childcode in enumerate(self.standard_sort):
- # for parent in self.predecessor_slots(childcode):
- # alpha[parent, child] = 1
- pname = self.nodes[childcode].get('parameter', None)
- muslots[child] = model.get_slot_x(pname)
- num = numpy.zeros(len(self.nodes), dtype=numpy.int32)
- start = numpy.full(len(self.nodes), -1, dtype=numpy.int32)
- n = self.n_edges
- for upcode in reversed(self.standard_sort):
- upslot = self.standard_slot_map[upcode]
- for dnslot in reversed(self.successor_slots(upcode)):
- n -= 1
- num[upslot] += 1
- start[upslot] = n
- return (muslots, start, num, )
-
- def _get_simple_mu_and_alpha(self, model, holdfast_invalidates=True):
- # alpha = numpy.zeros([len(self), len(self)], dtype=numpy.float64)
- mu = numpy.ones ([len(self), ], dtype=numpy.float64)
- muslots= numpy.full([len(self), ], -1, dtype=numpy.int32)
- for child, childcode in enumerate(self.standard_sort):
- # for parent in self.predecessor_slots(childcode):
- # alpha[parent, child] = 1
- pname = self.nodes[childcode].get('parameter', None)
- mu[child] = model.get_value(pname, default=1.0)
- muslots[child] = model.get_slot_x(pname, holdfast_invalidates)
-
- s = self.n_edges
- up = numpy.zeros(s, dtype=numpy.int32)
- dn = numpy.zeros(s, dtype=numpy.int32)
- val = numpy.zeros(s, dtype=numpy.float64)
- num = numpy.zeros(len(self.nodes), dtype=numpy.int32)
- start = numpy.full(len(self.nodes), -1, dtype=numpy.int32)
- #first_visit = numpy.zeros(s, dtype=numpy.int32)
- n = s
- #first_visit_found = set()
- for upcode in reversed(self.standard_sort):
- upslot = self.standard_slot_map[upcode]
- for dnslot in reversed(self.successor_slots(upcode)):
- n -= 1
- up[n] = upslot
- dn[n] = dnslot
- num[upslot] += 1
- start[upslot] = n
- val[n] = 1/len(self.predecessor_slots(self.standard_sort[dnslot])) # TODO make not always constant fraction
- # for n in range(s):
- # if dn[n] not in first_visit_found:
- # first_visit[n] = 1
- # first_visit_found.add(dn[n])
-
- return mu, muslots, up, dn, num, start, val
-
-def graph_to_figure(graph, output_format='svg', **format):
-
- try:
- import pygraphviz as viz
- except ImportError:
- import warnings
- warnings.warn("pygraphviz module not installed, unable to draw nesting tree")
- raise NotImplementedError("pygraphviz module not installed, unable to draw nesting tree")
- existing_format_keys = list(format.keys())
- for key in existing_format_keys:
- if key.upper()!=key: format[key.upper()] = format[key]
- if 'SUPPRESSGRAPHSIZE' not in format:
- if 'GRAPHWIDTH' not in format: format['GRAPHWIDTH'] = 6.5
- if 'GRAPHHEIGHT' not in format: format['GRAPHHEIGHT'] = 4
- if 'UNAVAILABLE' not in format: format['UNAVAILABLE'] = True
- # x = XML_Builder("div", {'class':"nesting_graph larch_art"})
- # x.h2("Nesting Structure", anchor=1, attrib={'class':'larch_art_xhtml'})
- from io import BytesIO
- if 'SUPPRESSGRAPHSIZE' not in format:
- G=viz.AGraph(name='Tree',directed=True,size="{GRAPHWIDTH},{GRAPHHEIGHT}".format(**format))
- else:
- G=viz.AGraph(name='Tree',directed=True)
- for n in graph.nodes:
- nname = graph.nodes[n].get('name', n)
- if nname == n:
- G.add_node(n, label='<{1}>'.format(n,nname), style='rounded,solid', shape='box')
- else:
- G.add_node(n, label='<{1} ({0})>'.format(n,nname), style='rounded,solid', shape='box')
- try:
- graph.elementals
- except AttributeError:
- pass
- else:
- eG = G.add_subgraph(name='cluster_elemental', nbunch=graph.elementals, color='#cccccc', bgcolor='#eeeeee',
- label='Elemental Alternatives', labelloc='b', style='rounded,solid')
- unavailable_nodes = set()
- # if format['UNAVAILABLE']:
- # if self.is_provisioned():
- # try:
- # for n, ncode in enumerate(self.alternative_codes()):
- # if numpy.sum(self.Data('Avail'),axis=0)[n,0]==0: unavailable_nodes.add(ncode)
- # except: raise
- # try:
- # legible_avail = not isinstance(self.df.queries.avail, str)
- # except:
- # legible_avail = False
- # if legible_avail:
- # for ncode,navail in self.df.queries.avail.items():
- # try:
- # if navail=='0': unavailable_nodes.add(ncode)
- # except: raise
- # eG.add_subgraph(name='cluster_elemental_unavailable', nbunch=unavailable_nodes, color='#bbbbbb', bgcolor='#dddddd',
- # label='Unavailable Alternatives', labelloc='b', style='rounded,solid')
- try:
- G.add_node(graph.root_id, label="Root")
- except AttributeError:
- pass
- up_nodes = set()
- down_nodes = set()
- for i,j in graph.edges:
- G.add_edge(i,j)
- down_nodes.add(j)
- up_nodes.add(i)
- pyg_imgdata = BytesIO()
- try:
- G.draw(pyg_imgdata, format=output_format, prog='dot') # write postscript in k5.ps with neato layout
- except ValueError as err:
- if 'in path' in str(err):
- import warnings
- warnings.warn(str(err)+"; unable to draw nesting tree in report")
- raise NotImplementedError()
- from xmle import Elem
- if output_format == 'svg':
- import xml.etree.ElementTree as ET
- ET.register_namespace("","http://www.w3.org/2000/svg")
- ET.register_namespace("xlink","http://www.w3.org/1999/xlink")
- result = ET.fromstring(pyg_imgdata.getvalue().decode())
- else:
- result = Elem('span', attrib={'style':'color:red'}, text=f"Unable to render output_format '{output_format}'")
- x = Elem('div') << result
- return x
+ node_dict_factory = OrderedDict
+ adjlist_dict_factory = OrderedDict
+
+ def __get__(self, instance, owner):
+ # self : SubkeyStore
+ # instance : instance of parent class that has `self` as a member, or None
+ # owner : class of `instance`
+ if instance is None:
+ pass # print("GRR: no instance")
+ return self
+ newself = getattr(instance, self.private_name, None)
+ if newself is None:
+ pass # print(f"GRR No Current: {instance=} {owner=}")
+ try:
+ instance.initialize_graph()
+ except ValueError:
+ pass
+ newself = getattr(instance, self.private_name, None)
+ if newself is not None:
+ newself._instance = instance
+ pass # print(f"GRR: get {instance=} {newself=}")
+ return newself
+
+ def __set__(self, instance, value):
+ # self : NestingTree object
+ # instance : instance of parent class that has `self` as a member
+ # value : the new value that is trying to be assigned
+ assert isinstance(value, NestingTree)
+ t = value.copy()
+ t._instance = instance
+ setattr(instance, self.private_name, t)
+ try:
+ t._instance.mangle()
+ except AttributeError as err:
+ pass # print(f"GRR: {err}")
+ else:
+ pass # print(f"GRR Mangle: {instance}")
+
+ def __delete__(self, instance):
+ setattr(instance, self.private_name, None)
+ try:
+ instance.mangle()
+ except AttributeError as err:
+ pass # print(f"GRR: {err}")
+ else:
+ pass # print(f"GRR Mangle: {instance}")
+
+ def __set_name__(self, owner, name):
+ self.name = name
+ self.private_name = "_private_" + name
+
+ def touch(self):
+ try:
+ self._instance.mangle()
+ except AttributeError:
+ pass # print("GRR: mangle failure")
+ else:
+ pass # print("GRR: mangle ok")
+
+ def __init__(self, *arg, root_id=0, suggested_elemental_order=(), **kwarg):
+ if len(arg) and isinstance(arg[0], NestingTree):
+ super().__init__(*arg, **kwarg)
+ self._root_id = arg[0]._root_id
+ if suggested_elemental_order != ():
+ self._suggested_elemental_order = suggested_elemental_order
+ else:
+ self._suggested_elemental_order = arg[0]._suggested_elemental_order
+ else:
+ super().__init__(*arg, **kwarg)
+ self._root_id = root_id
+ self._suggested_elemental_order = suggested_elemental_order
+ if self._root_id not in self.nodes:
+ self.add_node(root_id, name="_root_", root=True)
+ self._clear_caches()
+
+ def __eq__(self, other):
+ return (
+ self._adj == other._adj
+ and self._node == other._node
+ and self._root_id == other._root_id
+ and self._suggested_elemental_order == other._suggested_elemental_order
+ )
+
+ def suggest_elemental_order(self, order):
+ self._suggested_elemental_order = tuple(j for j in order if j in self.nodes)
+
+ @property
+ def root_id(self):
+ """int : The code for the root node."""
+ return self._root_id
+
+ @root_id.setter
+ def root_id(self, x):
+ top_nests = list(self.successors(self._root_id))
+ top_attrs = [self.edges[self._root_id, t] for t in top_nests]
+ if self._root_id in self.nodes:
+ self.remove_node(self._root_id)
+ self._root_id = x
+ if self._root_id not in self.nodes:
+ self.add_node(self._root_id, name="_root_", root=True)
+ for t, a in zip(top_nests, top_attrs):
+ self.add_edge(self._root_id, t, **a, _clear_caches=False)
+ self._clear_caches()
+
+ def _clear_caches(self):
+ NestingTree.topological_sorted.invalidate(self, "topological_sorted")
+ NestingTree.topological_sorted_no_elementals.invalidate(
+ self, "topological_sorted_no_elementals"
+ )
+ NestingTree.standard_sort.invalidate(self, "standard_sort")
+ NestingTree.standard_sort.invalidate(self, "standard_slot_map")
+ NestingTree.elementals.invalidate(self, "elementals")
+ NestingTree.standard_competitive_edge_list.invalidate(
+ self, "standard_competitive_edge_list"
+ )
+ NestingTree.standard_competitive_edge_list_2.invalidate(
+ self, "standard_competitive_edge_list_2"
+ )
+ self._predecessor_slots = {}
+ self._successor_slots = {}
+ self.touch()
+
+ def add_edge(self, u, v, implied=False, _clear_caches=True, **kwarg):
+ """
+ Add an edge between u and v.
+
+ The nodes u and v will be automatically added if they are
+ not already in the graph.
+
+ Edge attributes can be specified with keywords.
+
+ Parameters
+ ----------
+ u, v : int
+ Nodes should be integer codes. The upstream node `u` is
+ a nest or the root node. Downsteam node `v` can be
+ a nest or elemental alternative.
+ implied : bool, default False
+ Implied edges are for connection of otherwise unconnected
+ nests to the root node.
+ _clear_caches : bool, default True
+ kwarg : keyword arguments, optional
+ Edge data (or labels or objects) can be assigned using
+ keyword arguments.
+ """
+ if not implied:
+ drops = []
+ for u_, v_, imp_ in self.in_edges(nbunch=[v], data="implied"):
+ if imp_:
+ drops.append([u_, v_])
+ for d in drops:
+ super().remove_edge(*d)
+ if _clear_caches:
+ self._clear_caches()
+ return super().add_edge(int(u), int(v), implied=implied, **kwarg)
+
+ def _remove_edge_no_implied(self, u, v, *arg, **kwarg):
+ result = super().remove_edge(u, v)
+ self._clear_caches()
+ return result
+
+ def remove_edge(self, u, v, *arg, **kwarg):
+ """
+ Remove the edge between u and v.
+
+ Parameters
+ ----------
+ u, v : int
+ Remove the edge between nodes u and v.
+
+ Raises
+ ------
+ NetworkXError
+ If there is not an edge between u and v.
+ """
+ result = super().remove_edge(u, v)
+ if self.in_degree(v) == 0 and v != self._root_id:
+ self.add_edge(self._root_id, v, implied=True, _clear_caches=False)
+ self._clear_caches()
+ return result
+
+ def add_node(
+ self,
+ code,
+ *,
+ children=(),
+ parent=None,
+ parents=None,
+ phi_parameters=None,
+ **kwarg,
+ ):
+ """
+ Add a single node `code` and update node attributes.
+
+ Parameters
+ ----------
+ code : int
+ Although the generic networkx.DiGraph allows a node
+ to be any hashable Python object except None, Larch
+ assumes that node codes are integers.
+ children : Collection
+ A collection of other node codes that are the children
+ of this new node. Links will be created from this node
+ to each child.
+ parent : int, optional
+ The parent of this new node. If not given, the root
+ node is assumed to be the parent of this node, and an
+ implied link is created. This implied link is removed
+ if the node is later made the child of some other node.
+ If the parent is set explicitly, the link is *not*
+ removed later.
+ parents : Collection, optional
+ Set multiple parent up-stream nodes.
+ phi_parameters : Mapping
+ Set phi parameters on graph links connecting to this
+ node, used in network GEV models. The keys of this mapping
+ indicate the node at the other end of the link, and the
+ values are parameter names.
+ kwarg : other keyword arguments, optional
+ Set or change node attributes using key=value.
+ """
+ if parents is not None and parent is not None:
+ raise TypeError("cannot give both parent and parents arguments")
+ super().add_node(code, **kwarg)
+ for child in children:
+ self.add_edge(code, child, _clear_caches=False)
+ if parent is not None:
+ self.add_edge(parent, code, _clear_caches=False)
+ elif parents is not None:
+ for p in parents:
+ self.add_edge(p, code, _clear_caches=False)
+ else:
+ if self.in_degree(code) == 0 and code != self._root_id:
+ self.add_edge(self._root_id, code, implied=True, _clear_caches=False)
+ if phi_parameters is not None:
+ for k, parametername in phi_parameters.items():
+ if (code, k) in self.edges:
+ self.edges[code, k]["parameter"] = str(parametername)
+ elif (k, code) in self.edges:
+ self.edges[k, code]["parameter"] = str(parametername)
+ else:
+ raise ValueError(
+ f"connected node {k} from phi_parameters not found"
+ )
+ self._clear_caches()
+
+ def new_node(self, *, code=None, **kwarg):
+ """
+ Add a new nesting node to this NestingTree.
+
+ A new unique code is automatically created and returned by
+ this method for creating new nests.
+
+ All arguments must be given as keyword parameters.
+
+ Parameters
+ ----------
+ parameter : str
+ The name of the parameter to associate with this nest.
+ children : Collection[int], optional
+ The code numbers for the children of this nest. These can be
+ elemental alternatives or other nests. If not given, no children
+ will be defined initially, but they can be added later.
+ parent : int, optional
+ The code number for the parent of this nest. If not given,
+ the parent is implied as the root node, unless and until set
+ to some other node.
+ name : str, optional
+ A human-readable name to associate with this nest.
+ code : int, optional
+ Use this code for the new nest. If this code already exists,
+ a ValueError is raised.
+
+ Returns
+ -------
+ int
+ The new code for this nest.
+
+ Raises
+ ------
+ ValueError
+ If a new code is given but it already exists in this tree.
+ """
+ if code is None:
+ proposed_code = len(self)
+ while proposed_code in self:
+ proposed_code += 1
+ else:
+ if code in self:
+ raise ValueError(f"code {code} already exists in this tree")
+ proposed_code = code
+ self.add_node(proposed_code, **kwarg)
+ return proposed_code
+
+ def add_nodes(self, codes, *arg, parent=None, **kwarg):
+ for code in codes:
+ self.add_node(code, *arg, parent=parent, **kwarg)
+
+ def remove_node(self, n):
+ """
+ Remove node n.
+
+ Removes the node n, reconnecting all outedges to the head node
+ of all inedges. Attempting to remove a non-existent node will
+ raise an exception.
+
+ Parameters
+ ----------
+ n : int
+ A node in the graph
+ """
+ replace_edges = {k: self.edges[k].copy() for k in self.edges(n)}
+ replace_heads = [k for k, _ in self.in_edges(n)]
+ super(NestingTree, self).remove_node(n)
+ for k, attrs in replace_edges.items():
+ for h in replace_heads:
+ super().add_edge(h, k[1], **attrs)
+ self._clear_caches()
+
+ @lazy
+ def topological_sorted(self):
+ return list(reverse_lexicographical_topological_sort(self))
+
+ @lazy
+ def topological_sorted_no_elementals(self):
+ try:
+ result = self.topological_sorted.copy()
+ except nx.NetworkXUnfeasible:
+ from networkx.algorithms.cycles import find_cycle
+
+ try:
+ cycle = find_cycle(self)
+ except:
+ pass
+ else:
+ print("Found graph cycle:")
+ print(list(cycle))
+ raise
+
+ # collect zero out-degree (elemental) codes in a set
+ # to remove them in a batch, which is much faster than
+ # removing them from the list one at a time.
+ to_remove = set()
+ for code, out_degree in self.out_degree:
+ if not out_degree:
+ to_remove.add(code)
+ return [i for i in result if i not in to_remove]
+
+ @lazy
+ def standard_sort(self):
+ return self.elementals + tuple(self.topological_sorted_no_elementals)
+
+ def node_name(self, code):
+ return self.nodes[code].get("name", str(code))
+
+ @property
+ def standard_sort_names(self):
+ return [self.node_name(s) for s in self.standard_sort]
+
+ def node_names(self):
+ return {s: (self.node_name(s) or s) for s in self.standard_sort}
+
+ def elemental_names(self):
+ return {s: (self.node_name(s) or s) for s in self.elementals}
+
+ @lazy
+ def standard_slot_map(self):
+ return {i: n for n, i in enumerate(self.standard_sort)}
+
+ def predecessor_slots(self, code):
+ if code in self._predecessor_slots:
+ return self._predecessor_slots[code]
+ s = np.empty(self.in_degree(code), dtype=np.int32)
+ for n, i in enumerate(self.predecessors(code)):
+ s[n] = self.standard_slot_map[i]
+ self._predecessor_slots[code] = s
+ return s
+
+ def successor_slots(self, code):
+ if code in self._successor_slots:
+ return self._successor_slots[code]
+ s = np.empty(self.out_degree(code), dtype=np.int32)
+ for n, i in enumerate(self.successors(code)):
+ s[n] = self.standard_slot_map[i]
+ self._successor_slots[code] = s
+ return s
+
+ def __elementals_iter(self):
+ for code, out_degree in self.out_degree:
+ if not out_degree:
+ yield code
+
+ @lazy
+ def elementals(self):
+ result = []
+ found = set()
+ for e in self._suggested_elemental_order:
+ if self.out_degree(e) == 0:
+ result.append(e)
+ found.add(e)
+ for e in sorted(self.__elementals_iter()):
+ if e not in found:
+ result.append(e)
+ found.add(e)
+ return tuple(result)
+
+ def n_elementals(self):
+ return len(self.elementals)
+
+ def n_intermediate_nests(self):
+ return len(self.nodes) - self.n_elementals() - 1
+
+ def elemental_descendants_iter(self, code):
+ if not self.out_degree(code):
+ yield code
+ return
+ all_d = nx.descendants(self, code)
+ for dcode, dout_degree in self.out_degree(all_d):
+ if not dout_degree:
+ yield dcode
+
+ def elemental_descendants(self, code):
+ return [i for i in self.elemental_descendants_iter(code)]
+
+ @property
+ def n_edges(self):
+ return self.number_of_edges()
+
+ def edge_slot_arrays(self, alpha_locator=None):
+ s = self.n_edges
+ up = np.zeros(s, dtype=np.int32)
+ dn = np.zeros(s, dtype=np.int32)
+ first_visit = np.zeros(s, dtype=np.int32)
+ alloc_slot = np.full_like(first_visit, -1)
+ n = s
+ first_visit_found = set()
+ for upcode in reversed(self.standard_sort):
+ upslot = self.standard_slot_map[upcode]
+ for dnslot in reversed(self.successor_slots(upcode)):
+ n -= 1
+ up[n] = upslot
+ dn[n] = dnslot
+ for n in range(s):
+ if dn[n] not in first_visit_found:
+ first_visit[n] = 1
+ first_visit_found.add(dn[n])
+ if alpha_locator is not None:
+ for n in range(s):
+ alloc_slot[n] = alpha_locator.get(
+ (self.standard_sort[up[n]], self.standard_sort[dn[n]]), -1
+ )
+ return up, dn, first_visit, alloc_slot
+
+ def nodes_with_successors_iter(self):
+ for code, out_degree in self.out_degree:
+ if out_degree:
+ yield code
+
+ def nodes_with_multiple_predecessors_iter(self):
+ for code, in_degree in self.in_degree:
+ if in_degree > 1:
+ yield code
+
+ @lazy
+ def standard_competitive_edge_list(self):
+ alphas = []
+ for n in self.nodes_with_multiple_predecessors_iter():
+ predecessors = sorted(self.predecessors(n))
+ for k in predecessors:
+ alphas.append((k, n))
+ return alphas
+
+ @lazy
+ def standard_competitive_edge_list_2(self):
+ alphas = []
+ for n in self.nodes_with_multiple_predecessors_iter():
+ predecessors = sorted(self.predecessors(n))
+ alphas.append((predecessors, n))
+ return alphas
+
+ def __getstate__(self):
+ attr = {} # self.__dict__.copy()
+ no_pickle = (
+ "topological_sorted",
+ "topological_sorted_no_elementals",
+ "standard_sort",
+ "standard_slot_map",
+ #'_standard_elemental_sort',
+ "elementals",
+ "_predecessor_slots",
+ "_successor_slots",
+ "_touch",
+ "node_dict_factory",
+ "_instance",
+ )
+ for k, v in self.__dict__.items():
+ if k not in no_pickle:
+ attr[k] = v
+ return attr
+
+ def __setstate__(self, state):
+ self.__dict__ = state.copy()
+ self._predecessor_slots = {}
+ self._successor_slots = {}
+
+ def __xml__(self, use_viz=True, use_dot=True, output="svg", figsize=None, **format):
+ viz = None
+ dot = None
+ if use_viz:
+ try:
+ import pygraphviz as viz
+ except ImportError:
+ if use_dot:
+ try:
+ import pydot as dot
+ except ImportError:
+ pass
+ elif use_dot:
+ try:
+ import pydot as dot
+ except ImportError:
+ pass
+
+ if viz is None and dot is None:
+ import warnings
+
+ if use_viz and use_dot:
+ msg = "neither pydot nor pygraphviz modules are installed, unable to draw nesting tree"
+ elif use_viz:
+ msg = "pygraphviz module not installed, unable to draw nesting tree"
+ elif use_dot:
+ msg = "pydot module not installed, unable to draw nesting tree"
+ else:
+ msg = "no drawing module used, unable to draw nesting tree"
+ warnings.warn(msg)
+ raise NotImplementedError(msg)
+
+ if viz is not None:
+ existing_format_keys = list(format.keys())
+ for key in existing_format_keys:
+ if key.upper() != key:
+ format[key.upper()] = format[key]
+ if "SUPPRESSGRAPHSIZE" not in format:
+ if "GRAPHWIDTH" not in format:
+ format["GRAPHWIDTH"] = 6.5
+ if "GRAPHHEIGHT" not in format:
+ format["GRAPHHEIGHT"] = 4
+ if "UNAVAILABLE" not in format:
+ format["UNAVAILABLE"] = True
+ # x = XML_Builder("div", {'class':"nesting_graph larch_art"})
+ # x.h2("Nesting Structure", anchor=1, attrib={'class':'larch_art_xhtml'})
+ from io import BytesIO
+
+ if "SUPPRESSGRAPHSIZE" not in format:
+ G = viz.AGraph(
+ name="Tree",
+ directed=True,
+ size="{GRAPHWIDTH},{GRAPHHEIGHT}".format(**format),
+ )
+ else:
+ G = viz.AGraph(name="Tree", directed=True)
+ for n in self.nodes:
+ nname = self.nodes[n].get("name", n)
+ if nname == n:
+ G.add_node(
+ n,
+ label="<{1}>".format(n, nname),
+ style="rounded,solid",
+ shape="box",
+ )
+ else:
+ G.add_node(
+ n,
+ label='<{1} ({0})>'.format(
+ n, nname
+ ),
+ style="rounded,solid",
+ shape="box",
+ )
+ eG = G.add_subgraph(
+ name="cluster_elemental",
+ nbunch=self.elementals,
+ color="#cccccc",
+ bgcolor="#eeeeee",
+ label="Elemental Alternatives",
+ labelloc="b",
+ style="rounded,solid",
+ )
+ unavailable_nodes = set()
+ # if format['UNAVAILABLE']:
+ # if self.is_provisioned():
+ # try:
+ # for n, ncode in enumerate(self.alternative_codes()):
+ # if np.sum(self.Data('Avail'),axis=0)[n,0]==0: unavailable_nodes.add(ncode)
+ # except: raise
+ # try:
+ # legible_avail = not isinstance(self.df.queries.avail, str)
+ # except:
+ # legible_avail = False
+ # if legible_avail:
+ # for ncode,navail in self.df.queries.avail.items():
+ # try:
+ # if navail=='0': unavailable_nodes.add(ncode)
+ # except: raise
+ # eG.add_subgraph(name='cluster_elemental_unavailable', nbunch=unavailable_nodes, color='#bbbbbb', bgcolor='#dddddd',
+ # label='Unavailable Alternatives', labelloc='b', style='rounded,solid')
+ G.add_node(self.root_id, label="Root")
+ up_nodes = set()
+ down_nodes = set()
+ for i, j in self.edges:
+ G.add_edge(i, j)
+ down_nodes.add(j)
+ up_nodes.add(i)
+ pyg_imgdata = BytesIO()
+ try:
+ G.draw(
+ pyg_imgdata, format=output, prog="dot"
+ ) # write postscript in k5.ps with neato layout
+ except ValueError as err:
+ if "in path" in str(err):
+ import warnings
+
+ warnings.warn(str(err) + "; unable to draw nesting tree in report")
+ raise NotImplementedError()
+ if output == "svg":
+ import xml.etree.ElementTree as ET
+
+ ET.register_namespace("", "http://www.w3.org/2000/svg")
+ ET.register_namespace("xlink", "http://www.w3.org/1999/xlink")
+ return ET.fromstring(pyg_imgdata.getvalue().decode())
+ else:
+ raise NotImplementedError(f"output {output} with use_viz")
+ else:
+
+ pydot = dot
+
+ # set Graphviz graph type
+ if self.is_directed():
+ graph_type = "digraph"
+ else:
+ graph_type = "graph"
+ strict = nx.number_of_selfloops(self) == 0 and not self.is_multigraph()
+
+ name = self.name
+ graph_defaults = self.graph.get("graph", {})
+ if name == "":
+ P = pydot.Dot(
+ "", graph_type=graph_type, strict=strict, **graph_defaults
+ )
+ else:
+ P = pydot.Dot(
+ '"%s"' % name,
+ graph_type=graph_type,
+ strict=strict,
+ **graph_defaults,
+ )
+ try:
+ P.set_node_defaults(**self.graph["node"])
+ except KeyError:
+ pass
+ try:
+ P.set_edge_defaults(**self.graph["edge"])
+ except KeyError:
+ pass
+
+ cluster_elemental = pydot.Cluster(
+ "elemental",
+ style="rounded",
+ bgcolor="lightgrey",
+ color="white",
+ rank="same",
+ rankdir="LR",
+ )
+
+ for n, nodedata in self.nodes(data=True):
+ str_nodedata = dict(
+ (k if k != "name" else "name_", '"' + str(v) + '"')
+ for k, v in nodedata.items()
+ )
+
+ if "parameter" in nodedata:
+ param_label = '
{0}'.format(
+ nodedata["parameter"]
+ )
+ else:
+ param_label = ""
+
+ if "name" in nodedata and n != self.root_id:
+ name = nodedata["name"]
+ str_nodedata["label"] = (
+ "<"
+ '({1}) '
+ "{0}"
+ "{2}>".format(name, n, param_label)
+ )
+
+ # Default styling for nodes can have been overridden
+ if n in self.elementals:
+ str_nodedata["style"] = str_nodedata.get("style", "filled")
+ str_nodedata["fillcolor"] = str_nodedata.get("fillcolor", "white")
+ elif n == self.root_id:
+ str_nodedata["shape"] = str_nodedata.get("shape", "invhouse")
+ else:
+ str_nodedata["style"] = str_nodedata.get("style", "rounded")
+ str_nodedata["shape"] = str_nodedata.get("shape", "rectangle")
+
+ p = pydot.Node(str(n), **str_nodedata)
+ P.add_node(p)
+ if n in self.elementals:
+ cluster_elemental.add_node(p)
+
+ P.add_subgraph(cluster_elemental)
+
+ if self.is_multigraph():
+ for u, v, key, edgedata in self.edges(data=True, keys=True):
+ str_edgedata = dict(
+ (k, str(v_)) for k, v_ in edgedata.items() if k != "key"
+ )
+ if v in self.elementals:
+ str_edgedata["constraint"] = "false"
+ edge = pydot.Edge(str(u), str(v), key=str(key), **str_edgedata)
+ P.add_edge(edge)
+
+ else:
+ for u, v, edgedata in self.edges(data=True):
+ str_edgedata = dict(
+ (k, '"' + str(v) + '"') for k, v in edgedata.items()
+ )
+ edge = pydot.Edge(str(u), str(v), **str_edgedata)
+ P.add_edge(edge)
+
+ ###
+ from xmle import Elem
+
+ prog = None
+ if output == "svg":
+ import xml.etree.ElementTree as ET
+
+ ET.register_namespace("", "http://www.w3.org/2000/svg")
+ ET.register_namespace("xlink", "http://www.w3.org/1999/xlink")
+ elif output == "png":
+ prog = [P.prog, "-Gdpi=300"]
+ if figsize is not None:
+ prog.append(f"-Gsize={figsize[0]},{figsize[1]}\!")
+ e = Elem.from_any(P.create(prog=prog, format=output, **format))
+ e.attrib["dpi"] = (300, 300)
+ return e
+ return Elem.from_any(P.create(prog=prog, format=output, **format))
+
+ def _repr_html_(self):
+ from xmle import Elem
+
+ x = Elem("div") << (self.__xml__())
+ return x.tostring()
+
+ def to_png(self, figsize=None, filename=None):
+ """
+ Output the graph visualization as a png.
+
+ Parameters
+ ----------
+ figsize : 2-tuple, optional
+ The (width, height) in inches.
+
+ Returns
+ -------
+ xmle.Elem
+ """
+ result = self.__xml__(output="png", use_viz=False, figsize=figsize)
+ if filename is not None:
+ import base64
+
+ if result.attrib["src"][:22] != "data:image/png;base64,":
+ raise ValueError(
+ "problem decoding png:{}".format(result.attrib["src"][:22])
+ )
+ with open(filename, "wb") as fh:
+ fh.write(base64.decodebytes(result.attrib["src"][22:].encode()))
+ return result
+
+ def partial_figure(
+ self, including_nodes=None, source=None, *, n=None, n_at_level=3, n_expand=1
+ ):
+ """
+ Generate a partial figure of the graph.
+
+ Parameters
+ ----------
+ including_nodes : iterable or None
+ An iterable containing node codes or names (or a mix).
+ source : nodecode, optional
+ All paths from this node to everything in `including_nodes` will be represented.
+ Defaults to `root_id`.
+ n : int
+ If including_nodes is None, select this number of nodes randomly
+
+ Returns
+ -------
+ Elem
+ """
+ if source is None:
+ source = self.root_id
+ from networkx.algorithms.simple_paths import all_simple_paths
+
+ shows = set()
+
+ if including_nodes is None and n is not None:
+ including_nodes = sorted(np.random.choice(self.nodes, n, replace=False))
+
+ if including_nodes is None and n_at_level is not None:
+ import itertools
+ from collections import deque
+
+ q = deque([self.root_id])
+ including_nodes = []
+ while q:
+ i = q.popleft()
+ take = tuple(itertools.islice(self.successors(i), n_at_level))
+ q.extend(take[:n_expand])
+ including_nodes.extend(take)
+
+ # Add every node in every path from the root to each `including_nodes`
+ for each_node in including_nodes:
+ if each_node in self.nodes:
+ for i in all_simple_paths(self, source, each_node):
+ for j in i:
+ shows.add(j)
+ else:
+ for each_node_ in self.get_nodes_by_name(each_node):
+ for i in all_simple_paths(self, source, each_node_):
+ for j in i:
+ shows.add(j)
+ s = self.subgraph(shows)
+ return graph_to_figure(s)
+
+ def get_nodes_by_name(self, name):
+ result = [k for k, v in self.nodes(data=True) if v.get("name") == name]
+ return result
+
+ def subgraph_from(self, node):
+ from collections import deque
+
+ Q = deque([node])
+ found = set()
+ while len(Q):
+ i = Q.popleft()
+ if i not in found:
+ found.add(i)
+ Q.extend(self.successors(i))
+ return NestingTree(self.subgraph(found), root_id=node)
+
+ def stats_summarize(self):
+ print("Graph Stats")
+ print(f" Overall: {len(self)} nodes")
+ tier = [self.root_id]
+ next_tier = list(self.successors(self.root_id))
+ n = 0
+ while len(next_tier):
+ tier = next_tier
+ n += 1
+ print(f" Tier {n}: {len(tier)} nodes")
+ next_tier = list()
+ for i in tier:
+ next_tier.extend(self.successors(i))
+
+ def node_slot_arrays(self, model, parameter_dict=None):
+ if hasattr(model, "get_slot_x"):
+ muslots = np.full([len(self)], -1, dtype=np.int32)
+ for child, childcode in enumerate(self.standard_sort):
+ # for parent in self.predecessor_slots(childcode):
+ # alpha[parent, child] = 1
+ pname = self.nodes[childcode].get("parameter", None)
+ muslots[child] = model.get_slot_x(pname)
+ else:
+ muslots = np.ones([len(self)], dtype=model)
+ for child, childcode in enumerate(self.standard_sort):
+ # for parent in self.predecessor_slots(childcode):
+ # alpha[parent, child] = 1
+ pname = self.nodes[childcode].get("parameter", None)
+ if pname is not None:
+ if parameter_dict is not None and isinstance(pname, str):
+ pname = parameter_dict.get(pname, pname)
+ muslots[child] = model(pname)
+ num = np.zeros(len(self.nodes), dtype=np.int32)
+ start = np.full(len(self.nodes), -1, dtype=np.int32)
+ n = self.n_edges
+ for upcode in reversed(self.standard_sort):
+ upslot = self.standard_slot_map[upcode]
+ for dnslot in reversed(self.successor_slots(upcode)):
+ n -= 1
+ num[upslot] += 1
+ start[upslot] = n
+ return (
+ muslots,
+ start,
+ num,
+ )
+
+ def _get_simple_mu_and_alpha(self, model, holdfast_invalidates=True):
+ # alpha = np.zeros([len(self), len(self)], dtype=np.float64)
+ mu = np.ones(
+ [
+ len(self),
+ ],
+ dtype=np.float64,
+ )
+ muslots = np.full(
+ [
+ len(self),
+ ],
+ -1,
+ dtype=np.int32,
+ )
+ for child, childcode in enumerate(self.standard_sort):
+ # for parent in self.predecessor_slots(childcode):
+ # alpha[parent, child] = 1
+ pname = self.nodes[childcode].get("parameter", None)
+ mu[child] = model.get_value(pname, default=1.0)
+ muslots[child] = model.get_slot_x(pname, holdfast_invalidates)
+
+ s = self.n_edges
+ up = np.zeros(s, dtype=np.int32)
+ dn = np.zeros(s, dtype=np.int32)
+ val = np.zeros(s, dtype=np.float64)
+ num = np.zeros(len(self.nodes), dtype=np.int32)
+ start = np.full(len(self.nodes), -1, dtype=np.int32)
+ # first_visit = np.zeros(s, dtype=np.int32)
+ n = s
+ # first_visit_found = set()
+ for upcode in reversed(self.standard_sort):
+ upslot = self.standard_slot_map[upcode]
+ for dnslot in reversed(self.successor_slots(upcode)):
+ n -= 1
+ up[n] = upslot
+ dn[n] = dnslot
+ num[upslot] += 1
+ start[upslot] = n
+ val[n] = 1 / len(
+ self.predecessor_slots(self.standard_sort[dnslot])
+ ) # TODO make not always constant fraction
+ # for n in range(s):
+ # if dn[n] not in first_visit_found:
+ # first_visit[n] = 1
+ # first_visit_found.add(dn[n])
+
+ return mu, muslots, up, dn, num, start, val
+
+ def as_arrays(self, model=np.float32, trim=True, parameter_dict=None):
+ """
+ Express this tree as a dict of arrays for use with sharrow.
+
+ Parameters
+ ----------
+ model : Model or dtype
+ Give a model to extract MU values as parameter slot positions,
+ or a dtype to extract as
+ trim : bool, default True
+ Trim the node slot arrays to be only for nests.
+ parameter_dict : Mapping[str,Number], optional
+ Maps named parameters to values.
+
+ Returns
+ -------
+ dict
+ """
+ result = {}
+ result["n_nodes"] = len(self)
+ result["n_alts"] = n_alts = self.n_elementals()
+ up, dn, first_visit, alloc_slot = self.edge_slot_arrays()
+ result["edges_up"] = up
+ result["edges_dn"] = dn
+ result["edges_1st"] = first_visit
+ result["edges_alloc"] = alloc_slot
+ muslots, start, num = self.node_slot_arrays(
+ model=model, parameter_dict=parameter_dict
+ )
+ if trim:
+ muslots = muslots[n_alts:]
+ start = start[n_alts:]
+ num = num[n_alts:]
+ result["mu_params"] = muslots
+ result["start_slots"] = start
+ result["len_slots"] = num
+ return result
+
+
+def graph_to_figure(graph, output_format="svg", **format):
+
+ try:
+ import pygraphviz as viz
+ except ImportError:
+ import warnings
+
+ warnings.warn("pygraphviz module not installed, unable to draw nesting tree")
+ raise NotImplementedError(
+ "pygraphviz module not installed, unable to draw nesting tree"
+ )
+ existing_format_keys = list(format.keys())
+ for key in existing_format_keys:
+ if key.upper() != key:
+ format[key.upper()] = format[key]
+ if "SUPPRESSGRAPHSIZE" not in format:
+ if "GRAPHWIDTH" not in format:
+ format["GRAPHWIDTH"] = 6.5
+ if "GRAPHHEIGHT" not in format:
+ format["GRAPHHEIGHT"] = 4
+ if "UNAVAILABLE" not in format:
+ format["UNAVAILABLE"] = True
+ # x = XML_Builder("div", {'class':"nesting_graph larch_art"})
+ # x.h2("Nesting Structure", anchor=1, attrib={'class':'larch_art_xhtml'})
+ from io import BytesIO
+
+ if "SUPPRESSGRAPHSIZE" not in format:
+ G = viz.AGraph(
+ name="Tree",
+ directed=True,
+ size="{GRAPHWIDTH},{GRAPHHEIGHT}".format(**format),
+ )
+ else:
+ G = viz.AGraph(name="Tree", directed=True)
+ for n in graph.nodes:
+ nname = graph.nodes[n].get("name", n)
+ if nname == n:
+ G.add_node(
+ n, label="<{1}>".format(n, nname), style="rounded,solid", shape="box"
+ )
+ else:
+ G.add_node(
+ n,
+ label='<{1} ({0})>'.format(n, nname),
+ style="rounded,solid",
+ shape="box",
+ )
+ try:
+ graph.elementals
+ except AttributeError:
+ pass
+ else:
+ eG = G.add_subgraph(
+ name="cluster_elemental",
+ nbunch=graph.elementals,
+ color="#cccccc",
+ bgcolor="#eeeeee",
+ label="Elemental Alternatives",
+ labelloc="b",
+ style="rounded,solid",
+ )
+ unavailable_nodes = set()
+ # if format['UNAVAILABLE']:
+ # if self.is_provisioned():
+ # try:
+ # for n, ncode in enumerate(self.alternative_codes()):
+ # if np.sum(self.Data('Avail'),axis=0)[n,0]==0: unavailable_nodes.add(ncode)
+ # except: raise
+ # try:
+ # legible_avail = not isinstance(self.df.queries.avail, str)
+ # except:
+ # legible_avail = False
+ # if legible_avail:
+ # for ncode,navail in self.df.queries.avail.items():
+ # try:
+ # if navail=='0': unavailable_nodes.add(ncode)
+ # except: raise
+ # eG.add_subgraph(name='cluster_elemental_unavailable', nbunch=unavailable_nodes, color='#bbbbbb', bgcolor='#dddddd',
+ # label='Unavailable Alternatives', labelloc='b', style='rounded,solid')
+ try:
+ G.add_node(graph.root_id, label="Root")
+ except AttributeError:
+ pass
+ up_nodes = set()
+ down_nodes = set()
+ for i, j in graph.edges:
+ G.add_edge(i, j)
+ down_nodes.add(j)
+ up_nodes.add(i)
+ pyg_imgdata = BytesIO()
+ try:
+ G.draw(
+ pyg_imgdata, format=output_format, prog="dot"
+ ) # write postscript in k5.ps with neato layout
+ except ValueError as err:
+ if "in path" in str(err):
+ import warnings
+
+ warnings.warn(str(err) + "; unable to draw nesting tree in report")
+ raise NotImplementedError()
+ from xmle import Elem
+
+ if output_format == "svg":
+ import xml.etree.ElementTree as ET
+
+ ET.register_namespace("", "http://www.w3.org/2000/svg")
+ ET.register_namespace("xlink", "http://www.w3.org/1999/xlink")
+ result = ET.fromstring(pyg_imgdata.getvalue().decode())
+ else:
+ result = Elem(
+ "span",
+ attrib={"style": "color:red"},
+ text=f"Unable to render output_format '{output_format}'",
+ )
+ x = Elem("div") << result
+ return x
def reverse_lexicographical_topological_sort(G, key=None):
- """
- Generator of nodes in reverse lexicographically topologically sorted order.
-
- A general topological sort is a nonunique permutation of the nodes such that
- an edge from u to v implies that u appears before v in the topological sort
- order.
-
- The lexicographical topological sort breaks ties by ordering according to
- node labels, so that the sorting becomes unique.
-
- Parameters
- ----------
- G : NetworkX digraph
- A directed acyclic graph (DAG)
-
- key : function, optional
- This function maps nodes to keys with which to resolve ambiguities in
- the sort order. Defaults to the identity function.
-
- Returns
- -------
- iterable
- An iterable of node names in lexicographical topological sort order.
-
- Raises
- ------
- NetworkXError
- Topological sort is defined for directed graphs only. If the graph `G`
- is undirected, a :exc:`NetworkXError` is raised.
-
- NetworkXUnfeasible
- If `G` is not a directed acyclic graph (DAG) no topological sort exists
- and a :exc:`NetworkXUnfeasible` exception is raised. This can also be
- raised if `G` is changed while the returned iterator is being processed
-
- RuntimeError
- If `G` is changed while the returned iterator is being processed.
-
- """
-
- if not G.is_directed():
- msg = "Topological sort not defined on undirected graphs."
- raise nx.NetworkXError(msg)
-
- if key is None:
- def key(node):
- return node
-
- nodeid_map = {n: i for i, n in enumerate(G)}
-
- def create_tuple(node):
- return key(node), nodeid_map[node], node
-
- outdegree_map = {v: d for v, d in G.out_degree() if d > 0}
- # These nodes have zero outdegree and ready to be returned.
- zero_outdegree = [create_tuple(v) for v, d in G.out_degree() if d == 0]
- heapq.heapify(zero_outdegree)
-
- while zero_outdegree:
- _, _, node = heapq.heappop(zero_outdegree)
-
- if node not in G:
- raise RuntimeError("Graph changed during iteration")
- for parent, child in G.in_edges(node):
- try:
- outdegree_map[parent] -= 1
- except KeyError as e:
- raise RuntimeError("Graph changed during iteration") from e
- if outdegree_map[parent] == 0:
- heapq.heappush(zero_outdegree, create_tuple(parent))
- del outdegree_map[parent]
-
- yield node
-
- if zero_outdegree:
- msg = "Graph contains a cycle or graph changed during iteration"
- raise nx.NetworkXUnfeasible(msg)
+ """
+ Generator of nodes in reverse lexicographically topologically sorted order.
+
+ A general topological sort is a nonunique permutation of the nodes such that
+ an edge from u to v implies that u appears before v in the topological sort
+ order.
+
+ The lexicographical topological sort breaks ties by ordering according to
+ node labels, so that the sorting becomes unique.
+
+ Parameters
+ ----------
+ G : NetworkX digraph
+ A directed acyclic graph (DAG)
+
+ key : function, optional
+ This function maps nodes to keys with which to resolve ambiguities in
+ the sort order. Defaults to the identity function.
+
+ Returns
+ -------
+ iterable
+ An iterable of node names in lexicographical topological sort order.
+
+ Raises
+ ------
+ NetworkXError
+ Topological sort is defined for directed graphs only. If the graph `G`
+ is undirected, a :exc:`NetworkXError` is raised.
+
+ NetworkXUnfeasible
+ If `G` is not a directed acyclic graph (DAG) no topological sort exists
+ and a :exc:`NetworkXUnfeasible` exception is raised. This can also be
+ raised if `G` is changed while the returned iterator is being processed
+
+ RuntimeError
+ If `G` is changed while the returned iterator is being processed.
+
+ """
+
+ if not G.is_directed():
+ msg = "Topological sort not defined on undirected graphs."
+ raise nx.NetworkXError(msg)
+
+ if key is None:
+
+ def key(node):
+ return node
+
+ nodeid_map = {n: i for i, n in enumerate(G)}
+
+ def create_tuple(node):
+ return key(node), nodeid_map[node], node
+
+ outdegree_map = {v: d for v, d in G.out_degree() if d > 0}
+ # These nodes have zero outdegree and ready to be returned.
+ zero_outdegree = [create_tuple(v) for v, d in G.out_degree() if d == 0]
+ heapq.heapify(zero_outdegree)
+
+ while zero_outdegree:
+ _, _, node = heapq.heappop(zero_outdegree)
+
+ if node not in G:
+ raise RuntimeError("Graph changed during iteration")
+ for parent, child in G.in_edges(node):
+ try:
+ outdegree_map[parent] -= 1
+ except KeyError as e:
+ raise RuntimeError("Graph changed during iteration") from e
+ if outdegree_map[parent] == 0:
+ heapq.heappush(zero_outdegree, create_tuple(parent))
+ del outdegree_map[parent]
+
+ yield node
+
+ if zero_outdegree:
+ msg = "Graph contains a cycle or graph changed during iteration"
+ raise nx.NetworkXUnfeasible(msg)
diff --git a/larch/numba/model.py b/larch/numba/model.py
index 7e3023b3..9f92d87b 100755
--- a/larch/numba/model.py
+++ b/larch/numba/model.py
@@ -91,6 +91,9 @@ def quantity_from_data_ca(
holdfast_arr, # float input shape=[n_params]
array_av, # int8 input shape=[n_alts]
array_ca, # float input shape=[n_alts, n_ca_vars]
+ array_ce_data, # float input shape=[n_casealts, n_ca_vars]
+ array_ce_indices, # int input shape=[n_casealts]
+ array_ce_ptr, # int input shape=[2]
utility_elem, # float output shape=[n_alts]
dutility_elem, # float output shape=[n_alts, n_params]
):
@@ -103,26 +106,19 @@ def quantity_from_data_ca(
scale_param_value = 1.0
scale_param_holdfast = 1
- for j in range(n_alts):
-
- # if self._array_ce_reversemap is not None:
- # if c >= self._array_ce_reversemap.shape[0] or j >= self._array_ce_reversemap.shape[1]:
- # row = -1
- # else:
- # row = self._array_ce_reversemap[c, j]
- row = -1
-
- if array_av[j]: # and row != -1:
-
+ if array_ce_data.shape[0] > 0:
+ j = 0
+ for row in range(array_ce_ptr[0], array_ce_ptr[1]):
+ while array_ce_indices[row] > j:
+ # skipped alts are unavail, i.e. have zero size
+ utility_elem[j] = 0
+ j += 1
if model_q_ca_param.shape[0]:
for i in range(model_q_ca_param.shape[0]):
- # if row >= 0:
- # _temp = self._array_ce[row, self.model_quantity_ca_data[i]]
- # else:
_temp = (
- array_ca[j, model_q_ca_data[i]]
- * model_q_ca_param_scale[i]
- * np.exp(parameter_arr[model_q_ca_param[i]])
+ array_ce_data[row, model_q_ca_data[i]]
+ * model_q_ca_param_scale[i]
+ * np.exp(parameter_arr[model_q_ca_param[i]])
)
utility_elem[j] += _temp
if not holdfast_arr[model_q_ca_param[i]]:
@@ -137,8 +133,44 @@ def quantity_from_data_ca(
if (model_q_scale_param[0] >= 0) and not scale_param_holdfast:
dutility_elem[j, model_q_scale_param[0]] += _tempsize
- else:
- utility_elem[j] = -np.inf
+ j += 1
+ while n_alts > j:
+ # skipped alts are unavail, i.e. have zero size
+ utility_elem[j] = 0
+ j += 1
+ else:
+
+ for j in range(n_alts):
+
+ row = -1
+
+ if array_av[j]: # and row != -1:
+
+ if model_q_ca_param.shape[0]:
+ for i in range(model_q_ca_param.shape[0]):
+ # if row >= 0:
+ # _temp = self._array_ce[row, self.model_quantity_ca_data[i]]
+ # else:
+ _temp = (
+ array_ca[j, model_q_ca_data[i]]
+ * model_q_ca_param_scale[i]
+ * np.exp(parameter_arr[model_q_ca_param[i]])
+ )
+ utility_elem[j] += _temp
+ if not holdfast_arr[model_q_ca_param[i]]:
+ dutility_elem[j, model_q_ca_param[i]] += _temp * scale_param_value
+
+ for i in range(model_q_ca_param.shape[0]):
+ if not holdfast_arr[model_q_ca_param[i]]:
+ dutility_elem[j, model_q_ca_param[i]] /= utility_elem[j]
+
+ _tempsize = np.log(utility_elem[j])
+ utility_elem[j] = _tempsize * scale_param_value
+ if (model_q_scale_param[0] >= 0) and not scale_param_holdfast:
+ dutility_elem[j, model_q_scale_param[0]] += _tempsize
+
+ else:
+ utility_elem[j] = -np.inf
@njit(error_model='numpy', fastmath=True, cache=True)
@@ -554,6 +586,9 @@ def _numba_master(
holdfast_arr, # float input shape=[n_params]
array_av, # int8 input shape=[n_nodes]
array_ca, # float input shape=[n_alts, n_ca_vars]
+ array_ce_data, # float input shape=[n_casealts, n_ca_vars]
+ array_ce_indices, # int input shape=[n_casealts]
+ array_ce_ptr, # int input shape=[2]
utility[:n_alts], # float output shape=[n_alts]
dutility[:n_alts],
)
@@ -1372,7 +1407,7 @@ def _loglike_runner(
else:
penalty = 0.0
- except:
+ except Exception as error:
shp = lambda y: getattr(y, 'shape', 'scalar')
dtp = lambda y: getattr(y, 'dtype', f'{type(y)} ')
import inspect
@@ -1392,7 +1427,8 @@ def _loglike_runner(
for n, (a, s) in enumerate(zip(self.work_arrays, out_sig_shapes), start=n+1):
s = s.rstrip(" ),")
print(f" {arg_names[n]:{arg_name_width}} [{n:2}] {s.strip():9}: {dtp(a)}{shp(a)}")
- raise
+ if not isinstance(error, RuntimeError):
+ raise
return result_arrays, penalty
@property
diff --git a/setup.py b/setup.py
index e98603f6..ed27b8db 100644
--- a/setup.py
+++ b/setup.py
@@ -140,7 +140,7 @@ def find_pyx(path='.'):
install_requires=[
'numpy >=1.13',
'scipy >=1.0',
- 'pandas >=0.24',
+ 'pandas >=0.24,<1.5',
'tables >=3.4',
'cloudpickle',
'tqdm',
diff --git a/tests/test_numba.py b/tests/test_numba.py
index c7c1e619..15d02876 100644
--- a/tests/test_numba.py
+++ b/tests/test_numba.py
@@ -1153,3 +1153,54 @@ def test_eville_mode_with_dataset():
m.datatree = tree
assert m.loglike() == approx(-8047.006193851376)
assert m.n_cases == 20739
+
+
+def test_eville_idce_quant():
+ from larch.numba import example, DataTree
+ hh, pp, tour, skims, emp = example(200, ['hh', 'pp', 'tour', 'skims', 'emp'])
+ tour = tour.drop(
+ columns=["TOURMODE", "TOURPURP", "N_STOPS", "N_TRIPS", "N_TRIPS_HBW", "N_TRIPS_HBO", "N_TRIPS_NHB"])
+ tour = tour.merge(hh[["HHID", "HOMETAZ"]], on="HHID")
+ observations = tour[["TOURID", "DTAZ", "HOMETAZ"]].copy()
+ observations.TOURID += 1
+ # turn idco tours into idca and attach distance and employment
+ distance = pd.DataFrame(
+ np.array(skims['AUTO_DIST'])
+ ).rename_axis(index='otaz').unstack().reset_index().rename(columns={"level_0": "dtaz", 0: "distance"})
+ distance.dtaz += 1
+ distance.otaz += 1
+ obs_ca = pd.merge(
+ emp[["TOTAL_EMP"]].reset_index(),
+ observations,
+ how="cross"
+ )
+ assert observations.shape[0] * emp.shape[0] == obs_ca.shape[0]
+ obs_ca = obs_ca.merge(
+ distance.rename(columns={"otaz": "HOMETAZ", "dtaz": "TAZ"}),
+ on=["HOMETAZ", "TAZ"],
+ how="left")
+ obs_ca["chosen"] = (obs_ca.DTAZ == obs_ca.TAZ).astype('int')
+ obs_ca = obs_ca.set_index(["TOURID", "TAZ"])
+ obs_ca['avail'] = 1
+ # Sample x% to produce idce data
+ frac_to_keep = 0.9
+ obs_idce = pd.concat([
+ obs_ca.loc[obs_ca.chosen == 0].sample(frac=frac_to_keep, replace=False, random_state=1),
+ obs_ca.loc[obs_ca.chosen == 1]
+ ]).sort_index()
+ tree = DataTree(
+ obs=Dataset.construct.from_idce(obs_idce, crack=False),
+ )
+ mx = NumbaModel(datatree=tree)
+ mx.choice_ca_var = "chosen"
+ mx.quantity_ca = P.emp_p * X('TOTAL_EMP')
+ mx.quantity_scale = P.Theta
+ mx.utility_ca = P.distance * X.distance
+ mx.availability_var = 'avail'
+ mx.lock_values(emp_p=0, Theta=1)
+ mx.set_cap(10)
+ ll_init = mx.loglike()
+ assert ll_init == approx(-75619.33743479446)
+ result = mx.maximize_loglike()
+ assert result.loglike == approx(-69408.93781754425)
+ assert result.x['distance'] == approx(-0.37974107678625235)