Update pre-commit config (#1842)
* Upgrade black style to 21.11b1

* Fix bad runtime error string formatting.

This also caused flake8 to complain; it no longer does.

* Fix outstanding pre-commit issues

* Upgrade isort version in pre-commit hooks

* Update Lucas-C pre-commit hook version

* Update jumanjihouse pre-commit version

* Update pre-commit hooks

* Escape isort args

* Do not skip flake8 in pre-commit

* Fix outstanding pre-commit issues; ignore E741

* Apply isort

* Avoid isort 5 warning by double-dashing

* Fix black formatting in recent commit.

* Pin flake8 and flake8-print versions, skip flake8 for pre-commit run
Balandat committed Dec 2, 2021
1 parent df018d0 commit 7289e7d
Showing 46 changed files with 366 additions and 183 deletions.
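
Most of the Python hunks below are mechanical fallout from the black upgrade (19.10b0 to 21.11b1): since release 20.8b0, black treats a pre-existing trailing comma as a "magic trailing comma" and keeps the enclosed arguments one per line, where the old style packed them onto a single trailing-comma-terminated line whenever they fit the 120-character limit. A rough, self-contained sketch with a hypothetical helper (not from the repository):

```python
def configure(name, value, lower, upper):
    """Toy stand-in for the kernel constructors reformatted in this commit."""
    return {"name": name, "value": value, "bounds": (lower, upper)}


# Old style (black 19.10b0, -l 120): arguments packed onto one line, trailing comma kept:
#     params = configure("raw_angle", 0.0, 0.1, 0.9,)
# New style (black >= 20.8b0): the trailing comma is "magic", so the call stays exploded:
params = configure(
    "raw_angle",
    0.0,
    0.1,
    0.9,
)
print(params)
```
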
3 changes: 2 additions & 1 deletion .github/workflows/linting.yml
@@ -21,12 +21,13 @@ jobs:
python-version: "3.6"
- name: Install dependencies
run: |
pip install flake8==3.7.9 flake8-print==3.1.4 pre-commit
pip install flake8==4.0.1 flake8-print==4.0.0 pre-commit
pre-commit install
pre-commit run seed-isort-config || true
- name: Run linting
run: |
flake8
- name: Run pre-commit checks
# skipping flake8 here (run separately above b/c pre-commit does not include flake8-print)
run: |
SKIP=flake8 pre-commit run --files test/**/*.py gpytorch/**/*.py
21 changes: 12 additions & 9 deletions .pre-commit-config.yaml
@@ -1,10 +1,7 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.4.0
rev: v4.0.1
hooks:
- id: flake8
args: [--config=setup.cfg]
exclude: ^(examples/*)|(docs/*)
- id: check-byte-order-marker
- id: check-case-conflict
- id: check-merge-conflict
@@ -14,29 +11,35 @@ repos:
args: [--fix=lf]
- id: trailing-whitespace
- id: debug-statements
- repo: https://github.com/pycqa/flake8
rev: 4.0.1
hooks:
- id: flake8
args: [--config=setup.cfg]
exclude: ^(examples/*)|(docs/*)
- repo: https://github.com/ambv/black
rev: 19.10b0
rev: 21.11b1
hooks:
- id: black
exclude: ^(build/*)|(docs/*)|(examples/*)
args: [-l 120, --target-version=py36]
- repo: https://github.com/pre-commit/mirrors-isort
rev: v4.3.21
rev: v5.9.3
hooks:
- id: isort
language_version: python3
exclude: ^(build/*)|(docs/*)|(examples/*)
args: [-w 120, -m 3, -tc, --project=gpytorch]
args: [-w120, -m3, --tc, --project=gpytorch]
- repo: https://github.com/jumanjihouse/pre-commit-hooks
rev: 1.11.0
rev: 2.1.5
hooks:
- id: require-ascii
exclude: ^(examples/LBFGS.py)|(examples/.*\.ipynb)
- id: script-must-have-extension
- id: forbid-binary
exclude: ^(examples/*)
- repo: https://github.com/Lucas-C/pre-commit-hooks
rev: v1.1.7
rev: v1.1.10
hooks:
- id: forbid-crlf
- id: forbid-tabs
2 changes: 1 addition & 1 deletion docs/Makefile
@@ -17,4 +17,4 @@ help:
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
1 change: 0 additions & 1 deletion docs/source/module.rst
@@ -9,4 +9,3 @@ gpytorch.Module

.. autoclass:: gpytorch.Module
:members:

5 changes: 4 additions & 1 deletion gpytorch/kernels/additive_structure_kernel.py
@@ -42,7 +42,10 @@ def is_stationary(self) -> bool:
return self.base_kernel.is_stationary

def __init__(
self, base_kernel: Kernel, num_dims: int, active_dims: Optional[Tuple[int, ...]] = None,
self,
base_kernel: Kernel,
num_dims: int,
active_dims: Optional[Tuple[int, ...]] = None,
):
super(AdditiveStructureKernel, self).__init__(active_dims=active_dims)
self.base_kernel = base_kernel
16 changes: 12 additions & 4 deletions gpytorch/kernels/arc_kernel.py
@@ -119,26 +119,34 @@ def __init__(
# TODO: check the errors given by interval
angle_constraint = Interval(0.1, 0.9)
self.register_parameter(
name="raw_angle", parameter=torch.nn.Parameter(torch.zeros(*self.batch_shape, 1, self.last_dim)),
name="raw_angle",
parameter=torch.nn.Parameter(torch.zeros(*self.batch_shape, 1, self.last_dim)),
)
if angle_prior is not None:
if not isinstance(angle_prior, Prior):
raise TypeError("Expected gpytorch.priors.Prior but got " + type(angle_prior).__name__)
self.register_prior(
"angle_prior", angle_prior, lambda m: m.angle, lambda m, v: m._set_angle(v),
"angle_prior",
angle_prior,
lambda m: m.angle,
lambda m, v: m._set_angle(v),
)

self.register_constraint("raw_angle", angle_constraint)

self.register_parameter(
name="raw_radius", parameter=torch.nn.Parameter(torch.zeros(*self.batch_shape, 1, self.last_dim)),
name="raw_radius",
parameter=torch.nn.Parameter(torch.zeros(*self.batch_shape, 1, self.last_dim)),
)

if radius_prior is not None:
if not isinstance(radius_prior, Prior):
raise TypeError("Expected gpytorch.priors.Prior but got " + type(radius_prior).__name__)
self.register_prior(
"radius_prior", radius_prior, lambda m: m.radius, lambda m, v: m._set_radius(v),
"radius_prior",
radius_prior,
lambda m: m.radius,
lambda m, v: m._set_radius(v),
)

radius_constraint = Positive()
4 changes: 3 additions & 1 deletion gpytorch/kernels/distributional_input_kernel.py
@@ -25,7 +25,9 @@ class DistributionalInputKernel(Kernel):
has_lengthscale = True

def __init__(
self, distance_function: Callable, **kwargs,
self,
distance_function: Callable,
**kwargs,
):
super(DistributionalInputKernel, self).__init__(**kwargs)
if distance_function is None:
14 changes: 11 additions & 3 deletions gpytorch/kernels/grid_interpolation_kernel.py
@@ -114,7 +114,10 @@ def __init__(
grid = create_grid(self.grid_sizes, self.grid_bounds)

super(GridInterpolationKernel, self).__init__(
base_kernel=base_kernel, grid=grid, interpolation_mode=True, active_dims=active_dims,
base_kernel=base_kernel,
grid=grid,
interpolation_mode=True,
active_dims=active_dims,
)
self.register_buffer("has_initialized_grid", torch.tensor(has_initialized_grid, dtype=torch.bool))

@@ -170,7 +173,10 @@ def forward(self, x1, x2, diag=False, last_dim_is_batch=False, **params):
for x_min, x_max, spacing in zip(x_mins, x_maxs, grid_spacings)
)
grid = create_grid(
self.grid_sizes, self.grid_bounds, dtype=self.grid[0].dtype, device=self.grid[0].device,
self.grid_sizes,
self.grid_bounds,
dtype=self.grid[0].dtype,
device=self.grid[0].device,
)
self.update_grid(grid)

@@ -186,7 +192,9 @@ def forward(self, x1, x2, diag=False, last_dim_is_batch=False, **params):
right_interp_indices, right_interp_values = self._compute_grid(x2, last_dim_is_batch)

batch_shape = _mul_broadcast_shape(
base_lazy_tsr.batch_shape, left_interp_indices.shape[:-2], right_interp_indices.shape[:-2],
base_lazy_tsr.batch_shape,
left_interp_indices.shape[:-2],
right_interp_indices.shape[:-2],
)
res = InterpolatedLazyTensor(
base_lazy_tsr.expand(*batch_shape, *base_lazy_tsr.matrix_shape),
6 changes: 5 additions & 1 deletion gpytorch/kernels/polynomial_kernel.py
@@ -35,7 +35,11 @@ class PolynomialKernel(Kernel):
"""

def __init__(
self, power: int, offset_prior: Optional[Prior] = None, offset_constraint: Optional[Interval] = None, **kwargs,
self,
power: int,
offset_prior: Optional[Prior] = None,
offset_constraint: Optional[Interval] = None,
**kwargs,
):
super().__init__(**kwargs)
if offset_constraint is None:
5 changes: 4 additions & 1 deletion gpytorch/kernels/product_structure_kernel.py
@@ -48,7 +48,10 @@ def is_stationary(self) -> bool:
return self.base_kernel.is_stationary

def __init__(
self, base_kernel: Kernel, num_dims: int, active_dims: Optional[Tuple[int, ...]] = None,
self,
base_kernel: Kernel,
num_dims: int,
active_dims: Optional[Tuple[int, ...]] = None,
):
super(ProductStructureKernel, self).__init__(active_dims=active_dims)
self.base_kernel = base_kernel
8 changes: 4 additions & 4 deletions gpytorch/lazy/keops_lazy_tensor.py
@@ -57,8 +57,8 @@ def _getitem(self, row_index, col_index, *batch_indices):
elif isinstance(batch_indices, tuple):
if any(not isinstance(bi, slice) for bi in batch_indices):
raise RuntimeError(
f"Attempting to tensor index a non-batch matrix's batch dimensions. "
"Got batch index {batch_indices} but my shape was {self.shape}"
"Attempting to tensor index a non-batch matrix's batch dimensions. "
f"Got batch index {batch_indices} but my shape was {self.shape}"
)
x1 = x1.expand(*([1] * len(batch_indices)), *self.x1.shape[-2:])
x1 = x1[(*batch_indices, row_index, dim_index)]
@@ -74,8 +74,8 @@ def _getitem(self, row_index, col_index, *batch_indices):
elif isinstance(batch_indices, tuple):
if any([not isinstance(bi, slice) for bi in batch_indices]):
raise RuntimeError(
f"Attempting to tensor index a non-batch matrix's batch dimensions. "
"Got batch index {batch_indices} but my shape was {self.shape}"
"Attempting to tensor index a non-batch matrix's batch dimensions. "
f"Got batch index {batch_indices} but my shape was {self.shape}"
)
x2 = x2.expand(*([1] * len(batch_indices)), *self.x2.shape[-2:])
x2 = x2[(*batch_indices, row_index, dim_index)]
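
The two hunks above fix a misplaced f-string prefix: the `f` sat on the fragment with no placeholders, so `{batch_indices}` and `{self.shape}` were printed literally, and flake8 flags a placeholder-less f-string (F541), which matches the commit message. A standalone sketch with made-up values, not taken from the repository:

```python
batch_indices = (0, 1)
shape = (3, 4, 5)

# Buggy: the f-prefix is on the fragment without placeholders, so the braces
# below are emitted literally instead of being interpolated.
bad = (
    f"Attempting to tensor index a non-batch matrix's batch dimensions. "
    "Got batch index {batch_indices} but my shape was {shape}"
)

# Fixed: the f-prefix belongs on the fragment that holds the placeholders.
good = (
    "Attempting to tensor index a non-batch matrix's batch dimensions. "
    f"Got batch index {batch_indices} but my shape was {shape}"
)

print(bad)   # ... Got batch index {batch_indices} but my shape was {shape}
print(good)  # ... Got batch index (0, 1) but my shape was (3, 4, 5)
```
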
54 changes: 47 additions & 7 deletions gpytorch/lazy/lazy_evaluated_kernel_tensor.py
@@ -135,7 +135,13 @@ def _getitem(self, row_index, col_index, *batch_indices):
new_kernel = self.kernel.__getitem__(batch_indices)

# Now construct a kernel with those indices
return self.__class__(x1, x2, kernel=new_kernel, last_dim_is_batch=self.last_dim_is_batch, **self.params,)
return self.__class__(
x1,
x2,
kernel=new_kernel,
last_dim_is_batch=self.last_dim_is_batch,
**self.params,
)

def _matmul(self, rhs):
# This _matmul is defined to compute the kernel in chunks
@@ -156,7 +162,13 @@ def _matmul(self, rhs):
res = []
for sub_x1 in sub_x1s:
sub_kernel_matrix = lazify(
self.kernel(sub_x1, x2, diag=False, last_dim_is_batch=self.last_dim_is_batch, **self.params,)
self.kernel(
sub_x1,
x2,
diag=False,
last_dim_is_batch=self.last_dim_is_batch,
**self.params,
)
)
res.append(sub_kernel_matrix._matmul(rhs))

@@ -185,7 +197,13 @@ def _quad_form_derivative(self, left_vecs, right_vecs):
sub_x1.requires_grad_(True)
with torch.enable_grad(), settings.lazily_evaluate_kernels(False):
sub_kernel_matrix = lazify(
self.kernel(sub_x1, x2, diag=False, last_dim_is_batch=self.last_dim_is_batch, **self.params,)
self.kernel(
sub_x1,
x2,
diag=False,
last_dim_is_batch=self.last_dim_is_batch,
**self.params,
)
)
sub_grad_outputs = tuple(sub_kernel_matrix._quad_form_derivative(sub_left_vecs, right_vecs))
sub_kernel_outputs = tuple(sub_kernel_matrix.representation())
@@ -238,7 +256,11 @@ def _size(self):

def _transpose_nonbatch(self):
return self.__class__(
self.x2, self.x1, kernel=self.kernel, last_dim_is_batch=self.last_dim_is_batch, **self.params,
self.x2,
self.x1,
kernel=self.kernel,
last_dim_is_batch=self.last_dim_is_batch,
**self.params,
)

def add_jitter(self, jitter_val=1e-3):
@@ -247,7 +269,13 @@ def add_jitter(self, jitter_val=1e-3):
def _unsqueeze_batch(self, dim):
x1 = self.x1.unsqueeze(dim)
x2 = self.x2.unsqueeze(dim)
return self.__class__(x1, x2, kernel=self.kernel, last_dim_is_batch=self.last_dim_is_batch, **self.params,)
return self.__class__(
x1,
x2,
kernel=self.kernel,
last_dim_is_batch=self.last_dim_is_batch,
**self.params,
)

@cached(name="kernel_diag")
def diag(self):
@@ -291,7 +319,13 @@ def evaluate_kernel(self):
with settings.lazily_evaluate_kernels(False):
temp_active_dims = self.kernel.active_dims
self.kernel.active_dims = None
res = self.kernel(x1, x2, diag=False, last_dim_is_batch=self.last_dim_is_batch, **self.params,)
res = self.kernel(
x1,
x2,
diag=False,
last_dim_is_batch=self.last_dim_is_batch,
**self.params,
)
self.kernel.active_dims = temp_active_dims

# Check the size of the output
@@ -315,7 +349,13 @@ def repeat(self, *repeats):

x1 = self.x1.repeat(*batch_repeat, row_repeat, 1)
x2 = self.x2.repeat(*batch_repeat, col_repeat, 1)
return self.__class__(x1, x2, kernel=self.kernel, last_dim_is_batch=self.last_dim_is_batch, **self.params,)
return self.__class__(
x1,
x2,
kernel=self.kernel,
last_dim_is_batch=self.last_dim_is_batch,
**self.params,
)

def representation(self):
# If we're checkpointing the kernel, we'll use chunked _matmuls defined in LazyEvaluatedKernelTensor
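
For context on the chunked `_matmul` referenced in the comments above, here is a minimal, self-contained sketch of the row-chunking idea (illustrative only, with a toy kernel and made-up names; this is not the gpytorch implementation):

```python
import torch


def rbf(a, b):
    """Toy RBF kernel on 1-D inputs, standing in for a gpytorch kernel."""
    return torch.exp(-((a - b.transpose(-1, -2)) ** 2))


def chunked_matmul(kernel, x1, x2, rhs, chunk_rows=128):
    """Compute K(x1, x2) @ rhs without materialising the full kernel matrix:
    split x1 into row blocks, evaluate each block of K, multiply, and stack."""
    out = [kernel(sub_x1, x2) @ rhs for sub_x1 in x1.split(chunk_rows, dim=-2)]
    return torch.cat(out, dim=-2)


x1, x2 = torch.randn(512, 1), torch.randn(300, 1)
rhs = torch.randn(300, 4)
assert torch.allclose(chunked_matmul(rbf, x1, x2, rhs), rbf(x1, x2) @ rhs, atol=1e-5)
```
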
