Skip to content

Commit

Permalink
Add ruff’s pyupgrade to pre-commit (#165)
Browse files Browse the repository at this point in the history
  • Loading branch information
nstarman committed Aug 23, 2023
1 parent eb90a41 commit 33def2d
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 30 deletions.
7 changes: 7 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@ repos:
- id: debug-statements
exclude: docs

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: "v0.0.282"
hooks:
- id: ruff
args:
- --fix

- repo: https://github.com/psf/black
rev: "23.7.0"
hooks:
Expand Down
36 changes: 14 additions & 22 deletions daft.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
# -*- coding: utf-8 -*-

"""Code for Daft"""

__all__ = ["PGM", "Node", "Edge", "Plate"]
Expand All @@ -19,7 +17,7 @@
# pylint: disable=too-many-arguments, protected-access, unused-argument, too-many-lines


class PGM(object):
class PGM:
"""
The base object for building a graphical model representation.
Expand Down Expand Up @@ -504,7 +502,7 @@ def savefig(self, fname, *args, **kwargs):
self.figure.savefig(fname, *args, **kwargs)


class Node(object):
class Node:
"""
The representation of a random variable in a :class:`PGM`.
Expand Down Expand Up @@ -861,7 +859,7 @@ def get_frontier_coord(self, target_xy, ctx, edge):
raise ValueError("Wrong shape in object causes an error")


class Edge(object):
class Edge:
"""
An edge between two :class:`Node` objects.
Expand Down Expand Up @@ -1010,7 +1008,7 @@ def render(self, ctx):
return line


class Plate(object):
class Plate:
"""
A plate to encapsulate repeated independent processes in the model.
Expand Down Expand Up @@ -1133,7 +1131,7 @@ def render(self, ctx):
ha = "center"
else:
raise RuntimeError(
"Unknown positioning string: {0}".format(self.position)
f"Unknown positioning string: {self.position}"
)

if "bottom" in self.position:
Expand All @@ -1147,7 +1145,7 @@ def render(self, ctx):
va = "center"
else:
raise RuntimeError(
"Unknown positioning string: {0}".format(self.position)
f"Unknown positioning string: {self.position}"
)

ax.annotate(
Expand Down Expand Up @@ -1202,7 +1200,7 @@ def __init__(self, x, y, label, fontsize=None):
)


class _rendering_context(object):
class _rendering_context:
"""
:param shape:
The number of rows and columns in the grid.
Expand Down Expand Up @@ -1259,23 +1257,17 @@ def __init__(self, **kwargs):
# Make sure that the observed node style is one that we recognize.
self.observed_style = kwargs.get("observed_style", "shaded").lower()
styles = ["shaded", "inner", "outer"]
assert (
self.observed_style in styles
), "Unrecognized observed node style: {0}\n".format(
self.observed_style
) + "\tOptions are: {0}".format(
", ".join(styles)
assert self.observed_style in styles, (
f"Unrecognized observed node style: {self.observed_style}\n"
+ "\tOptions are: {}".format(", ".join(styles))
)

# Make sure that the alternate node style is one that we recognize.
self.alternate_style = kwargs.get("alternate_style", "inner").lower()
styles = ["shaded", "inner", "outer"]
assert (
self.alternate_style in styles
), "Unrecognized alternate node style: {0}\n".format(
self.alternate_style
) + "\tOptions are: {0}".format(
", ".join(styles)
assert self.alternate_style in styles, (
f"Unrecognized alternate node style: {self.alternate_style}\n"
+ "\tOptions are: {}".format(", ".join(styles))
)

# Set up the figure and grid dimensions.
Expand Down Expand Up @@ -1396,7 +1388,7 @@ def _pop_multiple(_dict, default, *args):

if len(results) > 1:
raise TypeError(
"The arguments ({0}) are equivalent, you can only provide one of them.".format(
"The arguments ({}) are equivalent, you can only provide one of them.".format(
", ".join([key for key, value in results])
)
)
Expand Down
2 changes: 0 additions & 2 deletions docs/conf.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
# -*- coding: utf-8 -*-

from pkg_resources import get_distribution, DistributionNotFound

try:
Expand Down
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,9 @@ include = ["daft*"]

[tool.black]
line-length = 79


[tool.ruff]
target-version = "py39"
line-length = 79
select = ["UP"]
2 changes: 0 additions & 2 deletions test/test_daft.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
# -*- coding: utf-8 -*-

import daft
import pytest

Expand Down
6 changes: 2 additions & 4 deletions test/test_examples.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
# -*- coding: utf-8 -*-

import itertools

import daft
Expand Down Expand Up @@ -429,7 +427,7 @@ def recurse(pgm, nodename, level, c):
if level > 4:
return nodename
r = c // 2
r1nodename = "r{0:02d}{1:04d}".format(level, r)
r1nodename = f"r{level:02d}{r:04d}"
if 2 * r == c:
# print("adding {0}".format(r1nodename))
pgm.add_node(
Expand Down Expand Up @@ -457,7 +455,7 @@ def recurse(pgm, nodename, level, c):
pgm.add_edge("query", "input")

for c in range(16):
nodename = "map {0:02d}".format(c)
nodename = f"map {c:02d}"
pgm.add_node(nodename, str(nodename), c, 3.0, aspect=1.9)
pgm.add_edge("input", nodename)
level = 1
Expand Down

0 comments on commit 33def2d

Please sign in to comment.