Writing Your Own LinearOperators

To define a new LinearOperator class, you must implement at least the following methods (in each description, A denotes the matrix that the LinearOperator represents):

  • linear_operator.operators.LinearOperator._matmul, which computes the matrix product AB for a tensor B.
  • linear_operator.operators.LinearOperator._size, which returns a torch.Size containing the dimensions of A.
  • linear_operator.operators.LinearOperator._transpose_nonbatch, which returns the transpose of the LinearOperator, swapping only the last two (matrix) dimensions and leaving any batch dimensions unchanged.

For maximum efficiency, the following methods should also be implemented:

  • linear_operator.operators.LinearOperator._bilinear_derivative, which computes the derivative of a bilinear form involving A with respect to the parameters that represent it (e.g. ∂(bᵀA(θ)c)/∂θ).
  • linear_operator.operators.LinearOperator._get_indices, which returns a torch.Tensor containing the elements of A selected by the given tensor indices.
  • linear_operator.operators.LinearOperator._expand_batch, which expands the batch dimensions of the LinearOperator.
  • linear_operator.operators.LinearOperator._check_args, which performs error checking on the arguments supplied to the LinearOperator constructor.

A LinearOperator may also need to define the following methods if it does anything nontrivial with the batch dimensions (e.g. sums along them or adds new ones): linear_operator.operators.LinearOperator._unsqueeze_batch, linear_operator.operators.LinearOperator._getitem, and linear_operator.operators.LinearOperator._permute_batch. See the documentation of these methods for details.

Note

The base LinearOperator class provides default implementations of many other operations so that it mimics the behavior of a standard tensor as closely as possible. For example, it provides default implementations of linear_operator.operators.LinearOperator.__getitem__, linear_operator.operators.LinearOperator.__add__, etc., which either make use of other linear operators or build on the methods that must be defined above.

Rather than overriding the public methods, we recommend overriding the private versions associated with them (e.g., write a custom _getitem rather than a custom __getitem__). The public methods perform a good deal of error checking and special-case handling that does not need to be repeated.