Skip to content

Commit

Permalink
Fix direction scaling
Browse files Browse the repository at this point in the history
  • Loading branch information
marcellodebernardi committed Aug 30, 2019
1 parent 7c92488 commit 8d34610
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
6 changes: 3 additions & 3 deletions loss_landscapes/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def random_line(model_start: typing.Union[torch.nn.Module, ModelWrapper], metric
else:
raise AttributeError('Unsupported normalization argument. Supported values are model, layer, and filter')

direction.mul_(((start_point.model_norm() / distance) / steps) / direction.model_norm())
direction.mul_(((start_point.model_norm() * distance) / steps) / direction.model_norm())

data_values = []
for i in range(steps):
Expand Down Expand Up @@ -290,8 +290,8 @@ def random_plane(model: typing.Union[torch.nn.Module, ModelWrapper], metric: Met
raise AttributeError('Unsupported normalization argument. Supported values are model, layer, and filter')

# scale to match steps and total distance
dir_one.mul_(((start_point.model_norm() / distance) / steps) / dir_one.model_norm())
dir_two.mul_(((start_point.model_norm() / distance) / steps) / dir_two.model_norm())
dir_one.mul_(((start_point.model_norm() * distance) / steps) / dir_one.model_norm())
dir_two.mul_(((start_point.model_norm() * distance) / steps) / dir_two.model_norm())
# Move start point so that original start params will be in the center of the plot
dir_one.mul_(steps / 2)
dir_two.mul_(steps / 2)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

setup(
name='loss_landscapes',
version='3.0.6',
version='3.0.7',
packages=find_packages(exclude='tests'),
url='https://github.com/marcellodebernardi/loss-landscapes',
license='MIT',
Expand Down

0 comments on commit 8d34610

Please sign in to comment.