Skip to content

Commit

Permalink
Merge pull request #9 from jmclong/v1.0.1
Browse files Browse the repository at this point in the history
chore: cut 1.0.1, add test, fix deprecation warning on meshgrid
  • Loading branch information
jmclong committed Oct 29, 2022
2 parents b2bc449 + dfcfc0d commit 9ec87f3
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 4 deletions.
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
author = 'Joshua Mark Coley Long'

# The full version, including alpha/beta/rc tags
release = '1.0.0'
release = '1.0.1'


# -- General configuration ---------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion rff/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
import rff.layers
import rff.dataloader

__version__ = '1.0.0'
__version__ = '1.0.1'
2 changes: 1 addition & 1 deletion rff/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def rectangular_coordinates(size: tuple) -> Tensor:
"""
def linspace_func(nx): return torch.linspace(0.0, 1.0, nx)
linspaces = map(linspace_func, size)
coordinates = torch.meshgrid(*linspaces)
coordinates = torch.meshgrid(*linspaces, indexing='ij')
return torch.stack(coordinates, dim=-1)


Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
long_description = (this_directory / "README.md").read_text()

setup(name='random-fourier-features-pytorch',
version='1.0.0',
version='1.0.1',
description='Random Fourier Features for PyTorch',
long_description=long_description,
long_description_content_type='text/markdown',
Expand Down
9 changes: 9 additions & 0 deletions test/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,15 @@ def test_gaussian_encoding(device):
decimal=5)


@pytest.mark.parametrize('device', ['cpu', 'cuda'])
def test_gaussian_encoding_no_unfreeze(device):
check_cuda(device)
b = rff.functional.sample_b(1.0, (256, 2)).to(device)
layer = rff.layers.GaussianEncoding(b=b).to(device)
layer.requires_grad = True
assert layer.b.requires_grad != True


@pytest.mark.parametrize('device', ['cpu', 'cuda'])
def test_gaussian_encoding_register_buffer(device):
check_cuda(device)
Expand Down

0 comments on commit 9ec87f3

Please sign in to comment.