Skip to content

Commit

Permalink
src layout (#166)
Browse files Browse the repository at this point in the history
Breaks `daft.py` into submodules.

Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Aug 25, 2023
1 parent 33def2d commit 62ea96c
Show file tree
Hide file tree
Showing 5 changed files with 241 additions and 221 deletions.
3 changes: 0 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,6 @@ Homepage = "http://daft-pgm.org"
Documentation = "http://docs.daft-pgm.org"
Repository = "https://github.com/daft-dev/daft"

[tool.setuptools.packages.find]
include = ["daft*"]


[tool.setuptools_scm]

Expand Down
14 changes: 14 additions & 0 deletions src/daft/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""Code for Daft"""

from importlib.metadata import version as get_distribution

from . import _core, _exceptions, _utils
from ._core import PGM, Node, Edge, Plate, Text
from ._exceptions import SameLocationError
from ._utils import _rendering_context, _pop_multiple

__version__ = get_distribution("daft")
__all__ = []
__all__ += _core.__all__
__all__ += _exceptions.__all__
__all__ += _utils.__all__
221 changes: 3 additions & 218 deletions daft.py → src/daft/_core.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
"""Code for Daft"""

__all__ = ["PGM", "Node", "Edge", "Plate"]

from importlib.metadata import version as get_distribution
# TODO: should Text be added?

import matplotlib as mpl
import matplotlib.pyplot as plt
Expand All @@ -12,7 +11,8 @@

import numpy as np

__version__ = get_distribution("daft")
from ._exceptions import SameLocationError
from ._utils import _rendering_context, _pop_multiple

# pylint: disable=too-many-arguments, protected-access, unused-argument, too-many-lines

Expand Down Expand Up @@ -1198,218 +1198,3 @@ def __init__(self, x, y, label, fontsize=None):
rect_params=self.rect_params,
bbox=self.bbox,
)


class _rendering_context:
"""
:param shape:
The number of rows and columns in the grid.
:param origin:
The coordinates of the bottom left corner of the plot.
:param grid_unit:
The size of the grid spacing measured in centimeters.
:param node_unit:
The base unit for the node size. This is a number in centimeters that
sets the default diameter of the nodes.
:param observed_style:
How should the "observed" nodes be indicated? This must be one of:
``"shaded"``, ``"inner"`` or ``"outer"`` where ``inner`` and
``outer`` nodes are shown as double circles with the second circle
plotted inside or outside of the standard one, respectively.
:param alternate_style: (optional)
How should the "alternate" nodes be indicated? This must be one of:
``"shaded"``, ``"inner"`` or ``"outer"`` where ``inner`` and
``outer`` nodes are shown as double circles with the second circle
plotted inside or outside of the standard one, respectively.
:param node_ec:
The default edge color for the nodes.
:param node_fc:
The default face color for the nodes.
:param plate_fc:
The default face color for plates.
:param directed:
Should the edges be directed by default?
:param aspect:
The default aspect ratio for the nodes.
:param label_params:
Default node label parameters.
:param dpi: (optional)
The DPI value to use for rendering.
"""

def __init__(self, **kwargs):
# Save the style defaults.
self.line_width = kwargs.get("line_width", 1.0)

# Make sure that the observed node style is one that we recognize.
self.observed_style = kwargs.get("observed_style", "shaded").lower()
styles = ["shaded", "inner", "outer"]
assert self.observed_style in styles, (
f"Unrecognized observed node style: {self.observed_style}\n"
+ "\tOptions are: {}".format(", ".join(styles))
)

# Make sure that the alternate node style is one that we recognize.
self.alternate_style = kwargs.get("alternate_style", "inner").lower()
styles = ["shaded", "inner", "outer"]
assert self.alternate_style in styles, (
f"Unrecognized alternate node style: {self.alternate_style}\n"
+ "\tOptions are: {}".format(", ".join(styles))
)

# Set up the figure and grid dimensions.
self.padding = 0.1
self.shp_fig_scale = 2.54

self.shape = np.array(kwargs.get("shape", [1, 1]), dtype=np.float64)
self.origin = np.array(kwargs.get("origin", [0, 0]), dtype=np.float64)
self.grid_unit = kwargs.get("grid_unit", 2.0)
self.figsize = self.grid_unit * self.shape / self.shp_fig_scale

self.node_unit = kwargs.get("node_unit", 1.0)
self.node_ec = kwargs.get("node_ec", "k")
self.node_fc = kwargs.get("node_fc", "w")
self.plate_fc = kwargs.get("plate_fc", "w")
self.directed = kwargs.get("directed", True)
self.aspect = kwargs.get("aspect", 1.0)
self.label_params = dict(kwargs.get("label_params", {}) or {})

self.dpi = kwargs.get("dpi", None)

# Initialize the figure to ``None`` to handle caching later.
self._figure = None
self._ax = None

def reset_shape(self, shape, adj_origin=False):
"""Reset the shape and figure size."""
# shape is scaled by grid_unit
# so divide by grid_unit for proper shape
self.shape = shape / self.grid_unit + self.padding
self.figsize = self.grid_unit * self.shape / self.shp_fig_scale

def reset_origin(self, origin, adj_shape=False):
"""Reset the origin."""
# origin is scaled by grid_unit
# so divide by grid_unit for proper shape
self.origin = origin / self.grid_unit - self.padding
if adj_shape:
self.shape -= self.origin
self.figsize = self.grid_unit * self.shape / self.shp_fig_scale

def reset_figure(self):
"""Reset the figure."""
self.close()

def close(self):
"""Close the figure if it is set up."""
if self._figure is not None:
plt.close(self._figure)
self._figure = None
self._ax = None

def figure(self):
"""Return the current figure else create a new one."""
if self._figure is not None:
return self._figure
args = {"figsize": self.figsize}
if self.dpi is not None:
args["dpi"] = self.dpi
self._figure = plt.figure(**args)
return self._figure

def ax(self):
"""Return the current axes else create a new one."""
if self._ax is not None:
return self._ax

# Add a new axis object if it doesn't exist.
self._ax = self.figure().add_axes(
(0, 0, 1, 1), frameon=False, xticks=[], yticks=[]
)

# Set the bounds.
l0 = self.convert(*self.origin)
l1 = self.convert(*(self.origin + self.shape))
self._ax.set_xlim(l0[0], l1[0])
self._ax.set_ylim(l0[1], l1[1])

return self._ax

def convert(self, *xy):
"""
Convert from model coordinates to plot coordinates.
"""
assert len(xy) == 2
return self.grid_unit * (np.atleast_1d(xy) - self.origin)


def _pop_multiple(_dict, default, *args):
"""
A helper function for dealing with the way that matplotlib annoyingly
allows multiple keyword arguments. For example, ``edgecolor`` and ``ec``
are generally equivalent but no exception is thrown if they are both
used.
*Note: This function does throw a :class:`TypeError` if more than one
of the equivalent arguments are provided.*
:param _dict:
A :class:`dict`-like object to "pop" from.
:param default:
The default value to return if none of the arguments are provided.
:param *args:
The arguments to try to retrieve.
"""
assert len(args) > 0, "You must provide at least one argument to `pop()`."

results = []
for arg in args:
try:
results.append((arg, _dict.pop(arg)))
except KeyError:
pass

if len(results) > 1:
raise TypeError(
"The arguments ({}) are equivalent, you can only provide one of them.".format(
", ".join([key for key, value in results])
)
)

if len(results) == 0:
return default

return results[0][1]


class SameLocationError(Exception):
"""
Exception to notify if two nodes are in the same position in the plot.
:param edge:
The Edge object whose nodes are being added.
"""

def __init__(self, edge):
self.message = (
"Attempted to add edge between `{}` and `{}` but they "
+ "share the same location."
).format(edge.node1.name, edge.node2.name)
super().__init__(self.message)
19 changes: 19 additions & 0 deletions src/daft/_exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""Daft errors"""

__all__: list[str] = []


class SameLocationError(Exception):
"""
Exception to notify if two nodes are in the same position in the plot.
:param edge:
The Edge object whose nodes are being added.
"""

def __init__(self, edge):
self.message = (
"Attempted to add edge between `{}` and `{}` but they "
+ "share the same location."
).format(edge.node1.name, edge.node2.name)
super().__init__(self.message)

0 comments on commit 62ea96c

Please sign in to comment.