Skip to content

Commit

Permalink
Merge pull request #127 from varunagrawal/fix/linter
Browse files Browse the repository at this point in the history
  • Loading branch information
dfm committed Apr 4, 2021
2 parents be2f7b0 + cdaeea6 commit 1c958ff
Showing 1 changed file with 40 additions and 29 deletions.
69 changes: 40 additions & 29 deletions daft.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# -*- coding: utf-8 -*-

"""Code for Daft"""

__all__ = ["PGM", "Node", "Edge", "Plate"]

from pkg_resources import get_distribution, DistributionNotFound
Expand All @@ -8,7 +10,7 @@
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse
from matplotlib.patches import FancyArrow
from matplotlib.patches import Rectangle as Rectangle
from matplotlib.patches import Rectangle

import numpy as np

Expand All @@ -17,6 +19,8 @@
except DistributionNotFound:
pass

# pylint: disable=too-many-arguments, protected-access, unused-argument, too-many-lines


class PGM(object):
"""
Expand Down Expand Up @@ -78,7 +82,7 @@ def __init__(
node_ec="k",
directed=True,
aspect=1.0,
label_params={},
label_params=None,
dpi=None,
):
self._nodes = {}
Expand Down Expand Up @@ -127,9 +131,9 @@ def add_node(
observed=False,
fixed=False,
alternate=False,
offset=[0.0, 0.0],
offset=(0.0, 0.0),
fontsize=None,
plot_params={},
plot_params=None,
label_params=None,
shape="ellipse",
):
Expand Down Expand Up @@ -221,9 +225,9 @@ def add_edge(
xoffset=0.0,
yoffset=0.1,
label=None,
plot_params={},
label_params={},
**kwargs
plot_params=None,
label_params=None,
**kwargs # pylint: disable=unused-argument
):
"""
Construct an :class:`Edge` between two named :class:`Node` objects.
Expand Down Expand Up @@ -280,7 +284,7 @@ def add_plate(
self,
plate,
label=None,
label_offset=[5, 5],
label_offset=(5, 5),
shift=0,
position="bottom left",
fontsize=None,
Expand Down Expand Up @@ -333,8 +337,6 @@ def add_plate(

self._plates.append(_plate)

return None

def add_text(self, x, y, label, fontsize=None):
"""
A subclass of plate to writing text using grid coordinates. Any
Expand Down Expand Up @@ -456,13 +458,15 @@ def get_min(minsize, artist):

@property
def figure(self):
"""Figure as a property."""
return self._ctx.figure()

@property
def ax(self):
"""Axes as a property."""
return self._ctx.ax()

def show(self, dpi=None, *args, **kwargs):
def show(self, *args, dpi=None, **kwargs):
"""
Wrapper on :class:`PGM.render()` that calls `matplotlib.show()`
immediately after.
Expand Down Expand Up @@ -563,9 +567,9 @@ def __init__(
observed=False,
fixed=False,
alternate=False,
offset=[0.0, 0.0],
offset=(0.0, 0.0),
fontsize=None,
plot_params={},
plot_params=None,
label_params=None,
shape="ellipse",
):
Expand Down Expand Up @@ -600,20 +604,14 @@ def __init__(
self.aspect = aspect

# Set fontsize
if fontsize is not None:
self.fontsize = fontsize
else:
self.fontsize = mpl.rcParams["font.size"]
self.fontsize = fontsize if fontsize else mpl.rcParams["font.size"]

# Display parameters.
self.plot_params = dict(plot_params)
self.plot_params = dict(plot_params) if plot_params else {}

# Text parameters.
self.offset = list(offset)
if label_params is not None:
self.label_params = dict(label_params)
else:
self.label_params = None
self.label_params = dict(label_params) if label_params else None

# Shape
if shape in ["ellipse", "rectangle"]:
Expand Down Expand Up @@ -856,7 +854,7 @@ def get_frontier_coord(self, target_xy, ctx, edge):

else:
# Should never append
raise (ValueError("Wrong shape in object causes an error"))
raise ValueError("Wrong shape in object causes an error")


class Edge(object):
Expand Down Expand Up @@ -904,17 +902,17 @@ def __init__(
label=None,
xoffset=0,
yoffset=0.1,
plot_params={},
label_params={},
plot_params=None,
label_params=None,
):
self.node1 = node1
self.node2 = node2
self.directed = directed
self.label = label
self.xoffset = xoffset
self.yoffset = yoffset
self.plot_params = dict(plot_params)
self.label_params = dict(label_params)
self.plot_params = dict(plot_params) if plot_params else {}
self.label_params = dict(label_params) if label_params else {}

def _get_coords(self, ctx):
"""
Expand Down Expand Up @@ -1050,7 +1048,7 @@ def __init__(
self,
rect,
label=None,
label_offset=[5, 5],
label_offset=(5, 5),
shift=0,
position="bottom left",
fontsize=None,
Expand Down Expand Up @@ -1283,7 +1281,7 @@ def __init__(self, **kwargs):
self.node_ec = kwargs.get("node_ec", "k")
self.directed = kwargs.get("directed", True)
self.aspect = kwargs.get("aspect", 1.0)
self.label_params = dict(kwargs.get("label_params", {}))
self.label_params = dict(kwargs.get("label_params", {}) or {})

self.dpi = kwargs.get("dpi", None)

Expand All @@ -1292,12 +1290,14 @@ def __init__(self, **kwargs):
self._ax = None

def reset_shape(self, shape, adj_origin=False):
"""Reset the shape and figure size."""
# shape is scaled by grid_unit
# so divide by grid_unit for proper shape
self.shape = shape / self.grid_unit + self.padding
self.figsize = self.grid_unit * self.shape / self.shp_fig_scale

def reset_origin(self, origin, adj_shape=False):
"""Reset the origin."""
# origin is scaled by grid_unit
# so divide by grid_unit for proper shape
self.origin = origin / self.grid_unit - self.padding
Expand All @@ -1306,15 +1306,18 @@ def reset_origin(self, origin, adj_shape=False):
self.figsize = self.grid_unit * self.shape / self.shp_fig_scale

def reset_figure(self):
"""Reset the figure."""
self.close()

def close(self):
"""Close the figure if it is set up."""
if self._figure is not None:
plt.close(self._figure)
self._figure = None
self._ax = None

def figure(self):
"""Return the current figure else create a new one."""
if self._figure is not None:
return self._figure
args = {"figsize": self.figsize}
Expand All @@ -1324,6 +1327,7 @@ def figure(self):
return self._figure

def ax(self):
"""Return the current axes else create a new one."""
if self._ax is not None:
return self._ax

Expand Down Expand Up @@ -1392,6 +1396,13 @@ def _pop_multiple(_dict, default, *args):


class SameLocationError(Exception):
"""
Exception to notify if two nodes are in the same position in the plot.
:param edge:
The Edge object whose nodes are being added.
"""

def __init__(self, edge):
self.message = (
"Attempted to add edge between `{}` and `{}` but they "
Expand Down

0 comments on commit 1c958ff

Please sign in to comment.