Lazy parameters and bijectors with metaclasses #59

Closed
wants to merge 1 commit into from
2 changes: 1 addition & 1 deletion .github/workflows/python-package.yml
@@ -41,7 +41,7 @@ jobs:
flake8 . --count --show-source --statistics
- name: Check types with mypy
run: |
mypy flowtorch
mypy --disallow-untyped-defs flowtorch
- name: Test with pytest
run: |
pytest --cov=tests --cov-report=xml -W ignore::DeprecationWarning tests/
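The --disallow-untyped-defs flag makes mypy report any function definition that lacks complete parameter and return annotations, which is why helper signatures later in this diff gain explicit types. A minimal illustration (function names here are hypothetical, for demonstration only):

# Flagged under --disallow-untyped-defs: no annotations
def scale(x):
    return 2 * x

# Accepted: parameters and return type are annotated
def scale_typed(x: float) -> float:
    return 2 * x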
4 changes: 4 additions & 0 deletions flowtorch/__init__.py
@@ -1,2 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# SPDX-License-Identifier: MIT

from flowtorch.lazy import Lazy, LazyMeta

__all__ = ["Lazy", "LazyMeta"]
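These new top-level exports are the heart of this PR: a Lazy placeholder plus a LazyMeta metaclass that let bijector and parameter classes be called with incomplete arguments and finished later. The flowtorch.lazy module itself is not part of this diff, so the following is only a rough sketch of the pattern under assumed semantics, not the actual implementation:

# Hypothetical sketch of the lazy-construction pattern; the real
# flowtorch.lazy internals are not shown in this pull request.
import functools
import inspect
from typing import Any


class Lazy:
    # Holds a class and partial constructor arguments until the missing
    # ones (for example the event shape) are supplied.
    def __init__(self, cls: type, *args: Any, **kwargs: Any) -> None:
        self.cls = cls
        self.partial = functools.partial(cls, *args, **kwargs)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Supplying the remaining arguments builds the real object.
        return self.partial(*args, **kwargs)


class LazyMeta(type):
    def __call__(cls, *args: Any, **kwargs: Any) -> Any:
        try:
            # If every required argument is present, construct eagerly.
            inspect.signature(cls.__init__).bind(None, *args, **kwargs)
        except TypeError:
            # Otherwise hand back a Lazy "plan" to be completed later.
            return Lazy(cls, *args, **kwargs)
        return super().__call__(*args, **kwargs)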
13 changes: 5 additions & 8 deletions flowtorch/bijectors/__init__.py
@@ -5,7 +5,6 @@
from typing import cast, List, Tuple

import torch
import torch.distributions as dist

# TODO: Autogenerate this from script!
from flowtorch.bijectors.affine_autoregressive import AffineAutoregressive
@@ -34,12 +33,12 @@
]


def isbijector(cls):
def isbijector(cls: type) -> bool:
# A class must inherit from flowtorch.Bijector to be considered a valid bijector
return issubclass(cls, Bijector)


def standard_bijector(cls):
def standard_bijector(cls: type) -> bool:
# "Standard bijectors" are the ones we can perform standard automated tests upon
return (
inspect.isclass(cls)
@@ -67,12 +66,10 @@ def standard_bijector(cls):
for bij_name, cls in standard_bijectors:
# TODO: Use factored out version of the following
# Define plan for flow
bij = cls()
event_dim = max(bij.domain.event_dim, 1)
event_dim = max(cls.domain.event_dim, 1) # type: ignore
event_shape = event_dim * [4]
base_dist = dist.Normal(torch.zeros(event_shape), torch.ones(event_shape))
flow = bij(base_dist)
bij = flow.bijector
# base_dist = dist.Normal(torch.zeros(event_shape), torch.ones(event_shape))
bij = cls(torch.Size(event_shape))

try:
y = torch.randn(*bij.forward_shape(event_shape))
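In the updated test-collection loop above, a standard bijector is constructed directly from its event shape instead of being built by applying a "plan" to a base distribution. A usage sketch of that style (only class and method names appearing in this diff are assumed to exist):

import torch
from flowtorch.bijectors import AffineAutoregressive

event_shape = torch.Size([4])
bij = AffineAutoregressive(event_shape)    # params default to DenseAutoregressive
x = torch.randn(*bij.forward_shape(event_shape))
y = bij.forward(x)                         # assuming forward(x, context=None), as in _forward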
46 changes: 24 additions & 22 deletions flowtorch/bijectors/affine_autoregressive.py
@@ -3,45 +3,54 @@

from typing import cast, Optional, Tuple

import flowtorch.params
import flowtorch
import flowtorch.parameters
import torch
import torch.distributions.constraints as constraints
from flowtorch.bijectors.base import Bijector
from flowtorch.ops import clamp_preserve_gradients
from torch.distributions.utils import _sum_rightmost


class AffineAutoregressive(Bijector):
# "Default" event shape is to operate on vectors
domain = constraints.real_vector
codomain = constraints.real_vector

# TODO: Remove when bijector/params type system is implemented
autoregressive = True

def __init__(
self,
param_fn: Optional[flowtorch.params.DenseAutoregressive] = None,
shape: torch.Size,
params: Optional[flowtorch.Lazy] = None,
context_size: int = 0,
*,
log_scale_min_clip: float = -5.0,
log_scale_max_clip: float = 3.0,
sigmoid_bias: float = 2.0,
context_size: int = 0,
) -> None:
# Event shape is determined by `shape` argument
self.domain = constraints.independent(constraints.real, len(shape))
self.codomain = constraints.independent(constraints.real, len(shape))

# currently only DenseAutoregressive has a `permutation` buffer
if not param_fn:
param_fn = flowtorch.params.DenseAutoregressive()
if not params:
params = flowtorch.parameters.DenseAutoregressive() # type: ignore

super().__init__(param_fn=param_fn)
super().__init__(shape, params, context_size)
self.log_scale_min_clip = log_scale_min_clip
self.log_scale_max_clip = log_scale_max_clip
self.sigmoid_bias = sigmoid_bias
self._context_size = context_size

def _forward(
self,
x: torch.Tensor,
context: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# TODO: lift into type system using thunk, see similar pattern for
# Param/ParamImpl
params = self.params
assert params is not None

mean, log_scale = params(x, context=context)
log_scale = clamp_preserve_gradients(
log_scale, self.log_scale_min_clip, self.log_scale_max_clip
@@ -55,8 +64,6 @@ def _inverse(
y: torch.Tensor,
context: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# TODO: lift into type system using thunk, see similar pattern for
# Param/ParamImpl
params = self.params
assert params is not None

@@ -74,7 +81,7 @@ def _inverse(
min=self.log_scale_min_clip,
max=self.log_scale_max_clip,
)
) # * 10
)
mean = mean[..., idx]
x[..., idx] = (y[..., idx] - mean) * inverse_scale

@@ -86,8 +93,6 @@ def _log_abs_det_jacobian(
y: torch.Tensor,
context: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# TODO: lift into type system using thunk, see similar pattern for
# Param/ParamImpl
params = self.params
assert params is not None

@@ -96,11 +101,8 @@
log_scale = clamp_preserve_gradients(
log_scale, self.log_scale_min_clip, self.log_scale_max_clip
)
return log_scale.sum(-1)

def param_shapes(
self, dist: torch.distributions.Distribution
) -> Tuple[torch.Size, torch.Size]:
# A mean and log variance for every dimension of base distribution
# TODO: Change this to reflect base dimension!
return torch.Size([]), torch.Size([])
return _sum_rightmost(log_scale, self.domain.event_dim)

def param_shapes(self, shape: torch.Size) -> Tuple[torch.Size, torch.Size]:
# A mean and log variance for every dimension of the event shape
return shape, shape
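With the new signature, the event shape comes first, the parameters are an optional lazy plan, and the clipping options become keyword-only; param_shapes now returns a mean shape and a log-scale shape matching the event shape. A hedged construction sketch (DenseAutoregressive options beyond the no-argument form are not shown in this diff):

import torch
import flowtorch.parameters
from flowtorch.bijectors import AffineAutoregressive

shape = torch.Size([8])
params = flowtorch.parameters.DenseAutoregressive()  # lazy plan; the shape is filled in by the bijector
bij = AffineAutoregressive(shape, params, log_scale_min_clip=-3.0)
assert bij.param_shapes(shape) == (shape, shape)      # mean and log-scale per event dimension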
13 changes: 11 additions & 2 deletions flowtorch/bijectors/affine_fixed.py
@@ -4,6 +4,7 @@
import math
from typing import Optional

import flowtorch
import torch
from flowtorch.bijectors.base import Bijector

@@ -16,8 +17,16 @@ class AffineFixed(Bijector):
"""

# TODO: Handle non-scalar loc and scale with correct broadcasting semantics
def __init__(self, loc=0.0, scale=1.0) -> None:
super().__init__(param_fn=None)
def __init__(
self,
shape: torch.Size,
params: Optional[flowtorch.Lazy] = None,
context_size: int = 0,
*,
loc: float = 0.0,
scale: float = 1.0
) -> None:
super().__init__(shape, params, context_size)
self.loc = loc
self.scale = scale

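AffineFixed keeps its fixed, non-learnable loc and scale, but they move behind the keyword-only marker after the shared shape/params/context_size arguments. A small sketch of constructing it under the new signature (the elementwise loc-plus-scale behaviour is implied by the parameter names, not shown in this hunk):

import torch
from flowtorch.bijectors.affine_fixed import AffineFixed

# Fixed affine bijector over a length-3 event; no learnable parameters needed
bij = AffineFixed(torch.Size([3]), loc=1.0, scale=2.0)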
140 changes: 20 additions & 120 deletions flowtorch/bijectors/base.py
@@ -1,78 +1,45 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# SPDX-License-Identifier: MIT
import weakref
from copy import deepcopy
from typing import Optional, Sequence, Union, cast
from typing import Optional, Sequence, Union

import flowtorch
import flowtorch.distributions
import flowtorch.params
import flowtorch.parameters
import torch
import torch.distributions
from flowtorch.params import ParamsModule
from flowtorch.parameters import Parameters
from torch.distributions import constraints


class Bijector(object):
_inv: Optional[Union[weakref.ReferenceType, "Bijector"]] = None
class Bijector(metaclass=flowtorch.LazyMeta):
# _inv: Optional[Union[weakref.ReferenceType, "Bijector"]] = None
codomain: constraints.Constraint = constraints.real
domain: constraints.Constraint = constraints.real
identity_initialization: bool = True
autoregressive: bool = False
_context_size: int
event_dim: int = 0
_params: Optional[flowtorch.params.ParamsModule] = None
_params: Optional[Union[Parameters, torch.nn.ModuleList]] = None

def __init__(
self,
param_fn: Optional[flowtorch.params.Params] = None,
shape: torch.Size,
params: Optional[flowtorch.Lazy] = None,
context_size: int = 0,
) -> None:
super().__init__()
self.param_fn = param_fn
self._context_size = context_size

def __call__(
self, base_dist: torch.distributions.Distribution
) -> flowtorch.distributions.TransformedDistribution:
"""
Returns the distribution formed by passing dist through the bijection
"""
if self.params is not None:
raise RuntimeError(
"Cannot instantiate a Bijector that has a non-None params attribute."
)

# If the input is a distribution then return transformed distribution
if isinstance(base_dist, torch.distributions.Distribution):
# Create transformed distribution
# TODO: Check that if bijector is autoregressive then parameters are as
# well Possibly do this in simplex.Bijector.__init__ and call from
# simple.bijectors.*.__init__
input_shape = (
base_dist.batch_shape + base_dist.event_shape # pyre-ignore[16]
)

# Instantiate hypernets on a copy of bijector, so self remains just a "plan"
self_copy = deepcopy(self)
if self_copy.param_fn is not None:
self_copy.params = self_copy.param_fn(
input_shape, self.param_shapes(base_dist), self_copy._context_size
) # <= this is where hypernets etc. are instantiated
new_dist = flowtorch.distributions.TransformedDistribution(
base_dist, self_copy
)
return new_dist

# TODO: Handle other types of inputs such as tensors
else:
raise TypeError(f"Bijector called with invalid type: {type(base_dist)}")
# Instantiate parameters (tensor, hypernets, etc.)
if params is not None:
shapes = self.param_shapes(shape)
self._params = params(shape, shapes, self._context_size) # type: ignore

@property
def params(self) -> Optional[ParamsModule]:
def params(self) -> Optional[Union[Parameters, torch.nn.ModuleList]]:
return self._params

@params.setter
def params(self, value: Optional[ParamsModule]):
def params(self, value: Optional[Union[Parameters, torch.nn.ModuleList]]) -> None:
self._params = value

def forward(
@@ -138,14 +105,13 @@ def _log_abs_det_jacobian(
# self.event_dim may be > 0 for derived classes!
return torch.zeros_like(x)

def param_shapes(
self, dist: torch.distributions.Distribution
) -> Sequence[torch.Size]:
def param_shapes(self, shape: torch.Size) -> Sequence[torch.Size]:
"""
Abstract method to return shapes of parameters
"""
raise NotImplementedError

"""
def inv(self) -> "Bijector":
if self._inv is not None:
# TODO: remove casting without failing mypy
@@ -154,87 +120,21 @@ def inv(self) -> "Bijector":
inv = _InverseBijector(self)
self._inv = weakref.ref(inv)
return inv
"""

def __repr__(self) -> str:
return self.__class__.__name__ + "()"

def forward_shape(self, shape):
def forward_shape(self, shape: torch.Size) -> torch.Size:
"""
Infers the shape of the forward computation, given the input shape.
Defaults to preserving shape.
"""
return shape

def inverse_shape(self, shape):
def inverse_shape(self, shape: torch.Size) -> torch.Size:
"""
Infers the shapes of the inverse computation, given the output shape.
Defaults to preserving shape.
"""
return shape


class _InverseBijector(Bijector):
_inv: Bijector
"""
Inverts a single :class:`Bijector`.
This class is private; please instead use the ``Bijector.inv`` property.
"""

def __init__(self, bijector: Bijector):
super(_InverseBijector, self).__init__(param_fn=bijector.param_fn)
self._inv = bijector
self.param_fn = bijector.param_fn
self.domain = bijector.codomain
self.codomain = bijector.domain
self._context_size = bijector._context_size

@property
def inv(self):
return self._inv

@property
def params(self):
return self.inv.params

@params.setter
def params(self, value):
self.inv.params = value

def __eq__(self, other):
if not isinstance(other, _InverseBijector):
return False
assert self._inv is not None
return self._inv == other._inv

def _forward(
self,
x: torch.Tensor,
context: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self._inv.inverse(x, context)

def _inverse(
self,
y: torch.Tensor,
context: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self._inv.forward(y, context)

def _log_abs_det_jacobian(
self,
x: torch.Tensor,
y: torch.Tensor,
context: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return -self._inv.log_abs_det_jacobian(y, x, context)

def param_shapes(
self, dist: torch.distributions.Distribution
) -> Sequence[torch.Size]:
return self._inv.param_shapes(dist)

def forward_shape(self, shape):
return self._inv.inverse_shape(shape)

def inverse_shape(self, shape):
return self._inv.forward_shape(shape)
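Taken together, Bijector now receives its event shape up front and instantiates its lazy parameters in __init__, while the LazyMeta metaclass presumably lets a bijector itself be declared without a shape and completed later; the __call__-on-a-distribution path and _InverseBijector are commented out or removed rather than ported. The sketch below illustrates the intended two-step flow; the exact Lazy call convention is an assumption, since flowtorch.lazy is not part of this diff:

import torch
import flowtorch
import flowtorch.parameters
from flowtorch.bijectors import AffineAutoregressive

# Omitting the required shape argument should yield a Lazy "plan" ...
plan = AffineAutoregressive(params=flowtorch.parameters.DenseAutoregressive())
assert isinstance(plan, flowtorch.Lazy)

# ... which is completed into a concrete bijector once the shape is known.
bij = plan(torch.Size([4]))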