diff --git a/backpack/extensions/firstorder/batch_dot_grad/base.py b/backpack/extensions/firstorder/batch_dot_grad/base.py
new file mode 100644
index 00000000..041e1447
--- /dev/null
+++ b/backpack/extensions/firstorder/batch_dot_grad/base.py
@@ -0,0 +1,31 @@
+"""Base class for extension ``BatchDotGrad``."""
+
+import torch
+
+from backpack.extensions.firstorder.base import FirstOrderModuleExtension
+
+
+class BatchDotGradBase(FirstOrderModuleExtension):
+    def __init__(self, derivatives, params=None):
+        self.derivatives = derivatives
+        super().__init__(params=params)
+
+    def bias(self, ext, module, g_inp, g_out, bpQuantities):
+        grad_batch = self.derivatives.bias_jac_t_mat_prod(
+            module, g_inp, g_out, g_out[0], sum_batch=False
+        )
+        return self.pairwise_dot(grad_batch)
+
+    def weight(self, ext, module, g_inp, g_out, bpQuantities):
+        grad_batch = self.derivatives.weight_jac_t_mat_prod(
+            module, g_inp, g_out, g_out[0], sum_batch=False
+        )
+        return self.pairwise_dot(grad_batch)
+
+    @staticmethod
+    def pairwise_dot(grad_batch):
+        """Compute pairwise dot products of individual gradients."""
+        # flatten all feature dimensions
+        grad_batch_flat = grad_batch.flatten(start_dim=1)
+        # pairwise dot product
+        return torch.einsum("if,jf->ij", grad_batch_flat, grad_batch_flat)
diff --git a/backpack/extensions/firstorder/batch_dot_grad/linear.py b/backpack/extensions/firstorder/batch_dot_grad/linear.py
index 59c39419..c7af76ce 100644
--- a/backpack/extensions/firstorder/batch_dot_grad/linear.py
+++ b/backpack/extensions/firstorder/batch_dot_grad/linear.py
@@ -1,31 +1,7 @@
-import torch
-
 from backpack.core.derivatives.linear import LinearDerivatives
-from backpack.extensions.firstorder.base import FirstOrderModuleExtension
+from backpack.extensions.firstorder.batch_dot_grad.base import BatchDotGradBase
 
 
-class BatchDotGradLinear(FirstOrderModuleExtension):
+class BatchDotGradLinear(BatchDotGradBase):
     def __init__(self):
-        self.derivatives = LinearDerivatives()
-        super().__init__(params=["bias", "weight"])
-
-    def bias(self, ext, module, g_inp, g_out, bpQuantities):
-        # Return value will be stored in savefield of extension
-        grad_batch = self.derivatives.bias_jac_t_mat_prod(
-            module, g_inp, g_out, g_out[0], sum_batch=False
-        )
-        return self.pairwise_dot(grad_batch)
-
-    def weight(self, ext, module, g_inp, g_out, bpQuantities):
-        # Return value will be stored in savefield of extension
-        grad_batch = self.derivatives.weight_jac_t_mat_prod(
-            module, g_inp, g_out, g_out[0], sum_batch=False
-        )
-        return self.pairwise_dot(grad_batch)
-
-    @staticmethod
-    def pairwise_dot(grad_batch):
-        # flatten all feature dimensions
-        grad_batch_flat = grad_batch.flatten(start_dim=1)
-        # pairwise dot product
-        return torch.einsum("if,jf->ij", grad_batch_flat, grad_batch_flat)
+        super().__init__(derivatives=LinearDerivatives(), params=["bias", "weight"])
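
A quick standalone sanity check of what pairwise_dot computes (a sketch outside the patch, with made-up tensor shapes chosen for illustration): given per-sample gradients stacked along the batch dimension, the einsum produces the [N, N] Gram matrix whose (i, j) entry is the dot product of sample i's and sample j's flattened gradient, which an explicit double loop confirms.

import torch

N, out_features, in_features = 4, 3, 5
# stand-in for grad_batch as returned by weight_jac_t_mat_prod with
# sum_batch=False: one gradient per sample, shape [N, *param.shape]
grad_batch = torch.randn(N, out_features, in_features)

# same computation as BatchDotGradBase.pairwise_dot
flat = grad_batch.flatten(start_dim=1)
gram = torch.einsum("if,jf->ij", flat, flat)

# reference: explicit dot product over all sample pairs
expected = torch.stack(
    [torch.stack([torch.dot(flat[i], flat[j]) for j in range(N)]) for i in range(N)]
)
assert torch.allclose(gram, expected)
print(gram.shape)  # torch.Size([4, 4])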