
Commit 2723d39

Merge 6bf5485 into 2c632c9

schaefertim committed Jul 27, 2021
2 parents 2c632c9 + 6bf5485
Showing 25 changed files with 246 additions and 879 deletions.
334 changes: 62 additions & 272 deletions backpack/core/derivatives/basederivatives.py

Large diffs are not rendered by default.
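Although the large diff is not rendered, the call sites changed below imply that this file introduces the unified `param_mjp` interface. A minimal sketch of its plausible shape, inferred from those call sites rather than from the unrendered diff itself:

```python
from typing import Callable, Tuple

from torch import Tensor
from torch.nn import Module


class BaseParameterDerivatives:
    """Sketch: one public matrix-Jacobian product (MJP) for all parameters."""

    def param_mjp(
        self,
        param_str: str,
        module: Module,
        g_inp: Tuple[Tensor],
        g_out: Tuple[Tensor],
        mat: Tensor,
        sum_batch: bool = True,
    ) -> Tensor:
        # Hypothetical dispatch: route by parameter name to a private
        # per-parameter implementation, replacing the old public
        # `weight_jac_t_mat_prod`, `bias_jac_t_mat_prod`, ... methods.
        impl: Callable[..., Tensor] = getattr(self, f"_{param_str}_jac_t_mat_prod")
        return impl(module, g_inp, g_out, mat, sum_batch=sum_batch)
```

In the actual file this method is presumably also wrapped by the shape-checking and vector-handling decorators from `shape_check.py` shown further down.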

3 changes: 3 additions & 0 deletions backpack/core/derivatives/lstm.py
@@ -179,6 +179,9 @@ def _ifgo_jac_t_mat_prod(
         )
         return IFGO_prod

+    def hessian_is_zero(self, module: LSTM) -> bool:  # noqa: D102
+        return False
+
     def _jac_mat_prod(
         self,
         module: LSTM,
3 changes: 3 additions & 0 deletions backpack/core/derivatives/rnn.py
@@ -47,6 +47,9 @@ def _check_parameters(module: RNN) -> None:
         if module.bidirectional is not False:
             raise NotImplementedError("only bidirectional = False is supported")

+    def hessian_is_zero(self, module: RNN) -> bool:  # noqa: D102
+        return False
+
     @staticmethod
     def _a_jac_t_mat_prod(
         output: Tensor,
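Both recurrent layers now report a non-zero elementwise Hessian, which is correct since their tanh/sigmoid gates are curved. A sketch of how a second-order backpropagation rule might branch on this flag; the consuming code is not part of this diff, and `residual_mat_prod` is named here only as an assumption:

```python
from torch import Tensor
from torch.nn import Module


def backpropagate(derivatives, module: Module, g_inp, g_out, mat: Tensor) -> Tensor:
    # Pull the backpropagated quantity through the module output.
    result = derivatives.jac_t_mat_prod(module, g_inp, g_out, mat)
    if not derivatives.hessian_is_zero(module):
        # Curved modules (RNN/LSTM gates use tanh/sigmoid) contribute a
        # residual term built from second derivatives; piecewise-linear
        # modules such as ReLU skip it.
        result = result + derivatives.residual_mat_prod(module, g_inp, g_out, mat)
    return result
```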
111 changes: 56 additions & 55 deletions backpack/core/derivatives/shape_check.py
@@ -54,7 +54,16 @@ def check_shape(mat: Tensor, like: Tensor, diff: int = 1) -> None:
         )


-def _check_same_V_dim(mat1, mat2):
+def check_same_V_dim(mat1, mat2):
+    """Check whether V dim (first dim) matches.
+
+    Args:
+        mat1: first tensor
+        mat2: second tensor
+
+    Raises:
+        RuntimeError: if V dim (first dim) doesn't match
+    """
     V1, V2 = mat1.shape[0], mat2.shape[0]
     if V1 != V2:
         raise RuntimeError("Number of vectors changed. Got {} and {}".format(V1, V2))
@@ -73,9 +82,19 @@ def _check_like(mat, module, name, diff=1, *args, **kwargs):
     return check_shape(mat, compare, diff=diff)


-def _check_like_with_sum_batch(mat, module, name, sum_batch=True, *args, **kwargs):
+def check_like_with_sum_batch(mat, module, name, sum_batch=True, *args, **kwargs):
+    """Checks shape, considers sum_batch.
+
+    Args:
+        mat: matrix to multiply
+        module: module
+        name: parameter to operate on: module.name
+        sum_batch: whether to consider with or without sum
+        *args: ignored
+        **kwargs: ignored
+    """
     diff = 1 if sum_batch else 2
-    return check_shape(mat, getattr(module, name), diff=diff)
+    check_shape(mat, getattr(module, name), diff=diff)
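Here `diff` counts the extra leading dimensions of `mat` relative to the parameter: one (`V`) when individual gradients were summed over the batch, two (`V, N`) when they were not. A shape-level illustration for a hypothetical `Linear(3, 2)` weight:

```python
import torch
from torch.nn import Linear

module = Linear(3, 2)  # module.weight.shape == (2, 3)
V, N = 5, 8

mat_summed = torch.zeros(V, 2, 3)         # sum_batch=True  -> diff=1
mat_per_sample = torch.zeros(V, N, 2, 3)  # sum_batch=False -> diff=2
```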


def _same_dim_as(mat, module, name, *args, **kwargs):
@@ -133,15 +152,6 @@ def _wrapped_mat_prod_accept_vectors(
     vec_criterion=same_dim_as_output,
 )

-weight_jac_t_mat_prod_accept_vectors = functools.partial(
-    _mat_prod_accept_vectors,
-    vec_criterion=same_dim_as_output,
-)
-bias_jac_t_mat_prod_accept_vectors = functools.partial(
-    _mat_prod_accept_vectors,
-    vec_criterion=same_dim_as_output,
-)

 jac_mat_prod_accept_vectors = functools.partial(
     _mat_prod_accept_vectors,
     vec_criterion=same_dim_as_input,
@@ -181,7 +191,7 @@ def wrapped_mat_prod_check_shapes(self, module, g_inp, g_out, mat, *args, **kwargs):
         in_check(mat, module, *args, **kwargs)
         mat_out = mat_prod(self, module, g_inp, g_out, mat, *args, **kwargs)
         out_check(mat_out, module, *args, **kwargs)
-        _check_same_V_dim(mat_out, mat)
+        check_same_V_dim(mat_out, mat)

         return mat_out

@@ -193,21 +203,6 @@ def wrapped_mat_prod_check_shapes(self, module, g_inp, g_out, mat, *args, **kwargs):
 shape_like_input = functools.partial(_check_like, name="input0")
 shape_like_weight = functools.partial(_check_like, name="weight")
 shape_like_bias = functools.partial(_check_like, name="bias")
-shape_like_weight_with_sum_batch = functools.partial(
-    _check_like_with_sum_batch, name="weight"
-)
-shape_like_bias_with_sum_batch = functools.partial(
-    _check_like_with_sum_batch, name="bias"
-)
-shape_like_bias_rnn_with_sum_batch = functools.partial(
-    _check_like_with_sum_batch, name="bias_ih_l0"
-)
-shape_like_weight_ih_with_sum_batch = functools.partial(
-    _check_like_with_sum_batch, name="weight_ih_l0"
-)
-shape_like_weight_hh_with_sum_batch = functools.partial(
-    _check_like_with_sum_batch, name="weight_hh_l0"
-)

 # decorators for shape checking
 jac_mat_prod_check_shapes = functools.partial(
@@ -226,33 +221,6 @@ def wrapped_mat_prod_check_shapes(self, module, g_inp, g_out, mat, *args, **kwargs):
     mat_prod_check_shapes, in_check=shape_like_output, out_check=shape_like_input
 )

-
-weight_jac_t_mat_prod_check_shapes = functools.partial(
-    mat_prod_check_shapes,
-    in_check=shape_like_output,
-    out_check=shape_like_weight_with_sum_batch,
-)
-bias_jac_t_mat_prod_check_shapes = functools.partial(
-    mat_prod_check_shapes,
-    in_check=shape_like_output,
-    out_check=shape_like_bias_with_sum_batch,
-)
-bias_rnn_jac_t_mat_prod_check_shapes = functools.partial(
-    mat_prod_check_shapes,
-    in_check=shape_like_output,
-    out_check=shape_like_bias_rnn_with_sum_batch,
-)
-weight_ih_jac_t_mat_prod_check_shapes = functools.partial(
-    mat_prod_check_shapes,
-    in_check=shape_like_output,
-    out_check=shape_like_weight_ih_with_sum_batch,
-)
-weight_hh_jac_t_mat_prod_check_shapes = functools.partial(
-    mat_prod_check_shapes,
-    in_check=shape_like_output,
-    out_check=shape_like_weight_hh_with_sum_batch,
-)

###############################################################################
# Wrapper for second-order extensions #
###############################################################################
@@ -327,3 +295,36 @@ def _new_hessian_mat_prod(mat):
return _new_hessian_mat_prod

return _wrapped_make_hessian_mat_prod


+def param_mjp_accept_vectors(mat_prod: Callable[..., Tensor]) -> Callable[..., Tensor]:
+    """Add support for vectors to matrix products.
+
+    vec_criterion(mat, module) returns if mat is a vector.
+
+    Args:
+        mat_prod: Function that processes multiple vectors in format of a matrix.
+
+    Returns:
+        Wrapped ``mat_prod`` function that processes multiple vectors in format of
+        a matrix, and supports vector-shaped inputs which are internally converted
+        to the correct format.
+
+        Preserves format of input:
+        If the input format is a vector, the output format is a vector.
+        If the input format is a matrix, the output format is a matrix.
+    """
+
+    @functools.wraps(mat_prod)
+    def _wrapped_mat_prod_accept_vectors(
+        self, param_str, module, g_inp, g_out, mat, *args, **kwargs
+    ):
+        is_vec = same_dim_as_output(mat, module)
+        mat_in = mat if not is_vec else _add_V_dim(mat)
+        mat_out = mat_prod(
+            self, param_str, module, g_inp, g_out, mat_in, *args, **kwargs
+        )
+        mat_out = mat_out if not is_vec else _remove_V_dim(mat_out)
+
+        return mat_out
+
+    return _wrapped_mat_prod_accept_vectors
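The `_add_V_dim`/`_remove_V_dim` helpers used above are defined elsewhere in the file; a minimal sketch of their likely behavior, an assumption based on their names and on `same_dim_as_output`:

```python
import torch


def _add_V_dim(mat: torch.Tensor) -> torch.Tensor:
    # Promote a vector to a single-slice "matrix" by prepending a V dim.
    return mat.unsqueeze(0)


def _remove_V_dim(mat: torch.Tensor) -> torch.Tensor:
    # Strip the leading V dim again, so vector inputs yield vector outputs.
    return mat.squeeze(0)


vec = torch.zeros(8, 2)  # shaped like a module output (N, out)
assert _add_V_dim(vec).shape == (1, 8, 2)
assert _remove_V_dim(_add_V_dim(vec)).shape == (8, 2)
```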
6 changes: 2 additions & 4 deletions backpack/extensions/curvmatprod/ggnmp/batchnorm1d.py
@@ -14,9 +14,7 @@ def weight(self, ext, module, g_inp, g_out, backproped):
         def weight_ggnmp(mat):
             result = self.derivatives.weight_jac_mat_prod(module, g_inp, g_out, mat)
             result = h_out_mat_prod(result)
-            result = self.derivatives.weight_jac_t_mat_prod(
-                module, g_inp, g_out, result
-            )
+            result = self.derivatives.param_mjp("weight", module, g_inp, g_out, result)

             return result

@@ -28,7 +26,7 @@ def bias(self, ext, module, g_inp, g_out, backproped):
         def bias_ggnmp(mat):
             result = self.derivatives.bias_jac_mat_prod(module, g_inp, g_out, mat)
             result = h_out_mat_prod(result)
-            result = self.derivatives.bias_jac_t_mat_prod(module, g_inp, g_out, result)
+            result = self.derivatives.param_mjp("bias", module, g_inp, g_out, result)

             return result
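All nine `curvmatprod` rules updated in this commit implement the same composition: for a parameter $p$, the curvature-matrix product computes

$$G_p M = J_p^{\top} \, H_{\mathrm{out}} \, (J_p M),$$

where $J_p$ is the Jacobian of the module output with respect to $p$ and $H_{\mathrm{out}}$ is the curvature proxy backpropagated to the module output (the GGN, Hessian, or positive-curvature Hessian block, respectively). Only the last factor changes here: the per-parameter `*_jac_t_mat_prod` call is replaced by the unified `param_mjp(param_str, ...)`.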
6 changes: 2 additions & 4 deletions backpack/extensions/curvmatprod/ggnmp/conv2d.py
@@ -12,9 +12,7 @@ def weight(self, ext, module, g_inp, g_out, backproped):
         def weight_ggnmp(mat):
             result = self.derivatives.weight_jac_mat_prod(module, g_inp, g_out, mat)
             result = h_out_mat_prod(result)
-            result = self.derivatives.weight_jac_t_mat_prod(
-                module, g_inp, g_out, result
-            )
+            result = self.derivatives.param_mjp("weight", module, g_inp, g_out, result)

             return result

@@ -26,7 +24,7 @@ def bias(self, ext, module, g_inp, g_out, backproped):
         def bias_ggnmp(mat):
             result = self.derivatives.bias_jac_mat_prod(module, g_inp, g_out, mat)
             result = h_out_mat_prod(result)
-            result = self.derivatives.bias_jac_t_mat_prod(module, g_inp, g_out, result)
+            result = self.derivatives.param_mjp("bias", module, g_inp, g_out, result)

             return result
6 changes: 2 additions & 4 deletions backpack/extensions/curvmatprod/ggnmp/linear.py
@@ -12,9 +12,7 @@ def weight(self, ext, module, g_inp, g_out, backproped):
         def weight_ggnmp(mat):
             result = self.derivatives.weight_jac_mat_prod(module, g_inp, g_out, mat)
             result = h_out_mat_prod(result)
-            result = self.derivatives.weight_jac_t_mat_prod(
-                module, g_inp, g_out, result
-            )
+            result = self.derivatives.param_mjp("weight", module, g_inp, g_out, result)

             return result

@@ -26,7 +24,7 @@ def bias(self, ext, module, g_inp, g_out, backproped):
         def bias_ggnmp(mat):
             result = self.derivatives.bias_jac_mat_prod(module, g_inp, g_out, mat)
             result = h_out_mat_prod(result)
-            result = self.derivatives.bias_jac_t_mat_prod(module, g_inp, g_out, result)
+            result = self.derivatives.param_mjp("bias", module, g_inp, g_out, result)

             return result
6 changes: 2 additions & 4 deletions backpack/extensions/curvmatprod/hmp/batchnorm1d.py
@@ -14,9 +14,7 @@ def weight(self, ext, module, g_inp, g_out, backproped):
         def weight_hmp(mat):
             result = self.derivatives.weight_jac_mat_prod(module, g_inp, g_out, mat)
             result = h_out_mat_prod(result)
-            result = self.derivatives.weight_jac_t_mat_prod(
-                module, g_inp, g_out, result
-            )
+            result = self.derivatives.param_mjp("weight", module, g_inp, g_out, result)

             return result

@@ -28,7 +26,7 @@ def bias(self, ext, module, g_inp, g_out, backproped):
         def bias_hmp(mat):
             result = self.derivatives.bias_jac_mat_prod(module, g_inp, g_out, mat)
             result = h_out_mat_prod(result)
-            result = self.derivatives.bias_jac_t_mat_prod(module, g_inp, g_out, result)
+            result = self.derivatives.param_mjp("bias", module, g_inp, g_out, result)

             return result
6 changes: 2 additions & 4 deletions backpack/extensions/curvmatprod/hmp/conv2d.py
@@ -12,9 +12,7 @@ def weight(self, ext, module, g_inp, g_out, backproped):
         def weight_hmp(mat):
             result = self.derivatives.weight_jac_mat_prod(module, g_inp, g_out, mat)
             result = h_out_mat_prod(result)
-            result = self.derivatives.weight_jac_t_mat_prod(
-                module, g_inp, g_out, result
-            )
+            result = self.derivatives.param_mjp("weight", module, g_inp, g_out, result)

             return result

@@ -26,7 +24,7 @@ def bias(self, ext, module, g_inp, g_out, backproped):
         def bias_hmp(mat):
             result = self.derivatives.bias_jac_mat_prod(module, g_inp, g_out, mat)
             result = h_out_mat_prod(result)
-            result = self.derivatives.bias_jac_t_mat_prod(module, g_inp, g_out, result)
+            result = self.derivatives.param_mjp("bias", module, g_inp, g_out, result)

             return result
6 changes: 2 additions & 4 deletions backpack/extensions/curvmatprod/hmp/linear.py
@@ -12,9 +12,7 @@ def weight(self, ext, module, g_inp, g_out, backproped):
         def weight_hmp(mat):
             result = self.derivatives.weight_jac_mat_prod(module, g_inp, g_out, mat)
             result = h_out_mat_prod(result)
-            result = self.derivatives.weight_jac_t_mat_prod(
-                module, g_inp, g_out, result
-            )
+            result = self.derivatives.param_mjp("weight", module, g_inp, g_out, result)

             return result

@@ -26,7 +24,7 @@ def bias(self, ext, module, g_inp, g_out, backproped):
         def bias_hmp(mat):
             result = self.derivatives.bias_jac_mat_prod(module, g_inp, g_out, mat)
             result = h_out_mat_prod(result)
-            result = self.derivatives.bias_jac_t_mat_prod(module, g_inp, g_out, result)
+            result = self.derivatives.param_mjp("bias", module, g_inp, g_out, result)

             return result
6 changes: 2 additions & 4 deletions backpack/extensions/curvmatprod/pchmp/conv2d.py
@@ -12,9 +12,7 @@ def weight(self, ext, module, g_inp, g_out, backproped):
         def weight_pchmp(mat):
             result = self.derivatives.weight_jac_mat_prod(module, g_inp, g_out, mat)
             result = h_out_mat_prod(result)
-            result = self.derivatives.weight_jac_t_mat_prod(
-                module, g_inp, g_out, result
-            )
+            result = self.derivatives.param_mjp("weight", module, g_inp, g_out, result)

             return result

@@ -26,7 +24,7 @@ def bias(self, ext, module, g_inp, g_out, backproped):
         def bias_pchmp(mat):
             result = self.derivatives.bias_jac_mat_prod(module, g_inp, g_out, mat)
             result = h_out_mat_prod(result)
-            result = self.derivatives.bias_jac_t_mat_prod(module, g_inp, g_out, result)
+            result = self.derivatives.param_mjp("bias", module, g_inp, g_out, result)

             return result
6 changes: 2 additions & 4 deletions backpack/extensions/curvmatprod/pchmp/linear.py
@@ -12,9 +12,7 @@ def weight(self, ext, module, g_inp, g_out, backproped):
         def weight_pchmp(mat):
             result = self.derivatives.weight_jac_mat_prod(module, g_inp, g_out, mat)
             result = h_out_mat_prod(result)
-            result = self.derivatives.weight_jac_t_mat_prod(
-                module, g_inp, g_out, result
-            )
+            result = self.derivatives.param_mjp("weight", module, g_inp, g_out, result)

             return result

@@ -26,7 +24,7 @@ def bias(self, ext, module, g_inp, g_out, backproped):
         def bias_pchmp(mat):
             result = self.derivatives.bias_jac_mat_prod(module, g_inp, g_out, mat)
             result = h_out_mat_prod(result)
-            result = self.derivatives.bias_jac_t_mat_prod(module, g_inp, g_out, result)
+            result = self.derivatives.param_mjp("bias", module, g_inp, g_out, result)

             return result
3 changes: 2 additions & 1 deletion backpack/extensions/firstorder/batch_grad/batch_grad_base.py
@@ -79,7 +79,8 @@ def param_function(
             Scaled individual gradients
         """
         subsampling = ext.get_subsampling()
-        return getattr(self._derivatives, f"{param_str}_jac_t_mat_prod")(
+        return self._derivatives.param_mjp(
+            param_str,
             module,
             g_inp,
             g_out,
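With `sum_batch=False`, `param_mjp` applied to the output gradient yields individual gradients. A plain-autograd sketch of the same quantity for a toy model, up to the loss reduction factor; the loop is for clarity only, BackPACK avoids it:

```python
import torch
from torch.nn import Linear, MSELoss

model, loss_fn = Linear(3, 2), MSELoss()
X, y = torch.randn(8, 3), torch.randn(8, 2)

# One gradient per sample; stacked shape is (8, 2, 3).
per_sample_grads = torch.stack(
    [
        torch.autograd.grad(loss_fn(model(x), t), model.weight)[0]
        for x, t in zip(X.split(1), y.split(1))
    ]
)
```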
4 changes: 2 additions & 2 deletions backpack/extensions/firstorder/batch_l2_grad/batch_l2_base.py
@@ -66,8 +66,8 @@ def param_function(
         """
         param_dims: List[int] = list(range(1, 1 + getattr(module, param_str).dim()))
         return (
-            getattr(self.derivatives, f"{param_str}_jac_t_mat_prod")(
-                module, g_inp, g_out, g_out[0], sum_batch=False
+            self.derivatives.param_mjp(
+                param_str, module, g_inp, g_out, g_out[0], sum_batch=False
             )
             ** 2
         ).sum(param_dims)
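The individual gradients are squared elementwise and summed over every parameter dimension, leaving only the batch axis. A shape-level check for a weight of shape `(2, 3)`:

```python
import torch

N = 8
individual_grads = torch.randn(N, 2, 3)  # result of param_mjp(..., sum_batch=False)

batch_l2 = (individual_grads ** 2).sum(dim=[1, 2])  # per-sample squared L2 norms
assert batch_l2.shape == (N,)
```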
8 changes: 4 additions & 4 deletions backpack/extensions/firstorder/gradient/base.py
@@ -35,11 +35,11 @@ def __init__(self, derivatives, params):
             setattr(self, param_str, self._make_param_function(param_str))
         super().__init__(params=params)

-    def _make_param_function(self, param):
+    def _make_param_function(self, param_str):
         """Creates a function that calculates gradient wrt param.

         Args:
-            param(str): name of parameter
+            param_str: name of parameter

         Returns:
             function: function that calculates gradient wrt param
@@ -58,8 +58,8 @@ def param_function(ext, module, g_inp, g_out, bpQuantities):
             Returns:
                 torch.Tensor: gradient of the batch, similar to autograd
             """
-            return getattr(self.derivatives, f"{param}_jac_t_mat_prod")(
-                module, g_inp, g_out, g_out[0], sum_batch=True
+            return self.derivatives.param_mjp(
+                param_str, module, g_inp, g_out, g_out[0], sum_batch=True
            )

         return param_function
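With `sum_batch=True` the matrix-Jacobian product of `g_out[0]` reproduces what autograd accumulates into `.grad`. A quick consistency check using only standard PyTorch:

```python
import torch
from torch.nn import Linear

module = Linear(3, 2)
x = torch.randn(8, 3)
out = module(x)
g_out = torch.randn_like(out)  # plays the role of g_out[0]

# Gradient wrt weight is J_W^T g_out, summed over the batch.
(grad_weight,) = torch.autograd.grad(out, module.weight, grad_outputs=g_out)
assert grad_weight.shape == module.weight.shape
```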
4 changes: 2 additions & 2 deletions backpack/extensions/firstorder/sum_grad_squared/sgs_base.py
@@ -64,8 +64,8 @@ def param_function(
             sum_grad_squared
         """
         return (
-            getattr(self.derivatives, f"{param_str}_jac_t_mat_prod")(
-                module, g_inp, g_out, g_out[0], sum_batch=False
+            self.derivatives.param_mjp(
+                param_str, module, g_inp, g_out, g_out[0], sum_batch=False
             )
             ** 2
         ).sum(self.N_axis)
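`sum_grad_squared` uses the complementary reduction to `batch_l2`: square first, then sum over the batch axis (`self.N_axis`, presumably 0), keeping the parameter shape. Sketch:

```python
import torch

N = 8
individual_grads = torch.randn(N, 2, 3)  # result of param_mjp(..., sum_batch=False)

sgs = (individual_grads ** 2).sum(0)  # reduce over the batch axis, N_axis = 0
assert sgs.shape == (2, 3)
# Note: this differs from squaring the summed gradient, (sum_n g_n) ** 2.
```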
