gpleiss/cola

Compositional Linear Algebra (CoLA)

CoLA is a framework for scalable linear algebra, automatically exploiting the structure often found in machine learning problems and beyond. CoLA supports both PyTorch and JAX.

Installation

pip install git+https://github.com/wilson-labs/cola.git

Features in CoLA

  • Provides several compositional rules to exploit problem structure through multiple dispatch.
  • Works with PyTorch and JAX.
  • Supports hardware acceleration through GPU and TPU (JAX).
  • Supports different types of numerical precision.
  • Has memory-efficient Autograd routines for different iterative algorithms.
  • Provides operations for both symmetric and non-symmetric matrices.
  • Runs with real and complex numbers.
  • Contains several randomized linear algebra algorithms.
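The multiple-dispatch idea in the first bullet can be sketched in plain Python: choose an algorithm based on the operator's type, and recurse into the components of composite operators. Below is a minimal, hypothetical example using `functools.singledispatch` and NumPy; the classes `DenseOp`, `DiagonalOp`, `KroneckerOp` and the function `inverse` are illustrative stand-ins, not CoLA's actual classes or dispatch system.

```python
from functools import singledispatch

import numpy as np


class DenseOp:
    """Hypothetical operator that stores an explicit matrix."""
    def __init__(self, matrix):
        self.matrix = np.asarray(matrix, dtype=float)

    def to_dense(self):
        return self.matrix


class DiagonalOp:
    """Hypothetical operator that stores only its diagonal."""
    def __init__(self, diag):
        self.diag = np.asarray(diag, dtype=float)

    def to_dense(self):
        return np.diag(self.diag)


class KroneckerOp:
    """Hypothetical operator for a ⊗ b that stores only its two factors."""
    def __init__(self, a, b):
        self.a, self.b = a, b

    def to_dense(self):
        return np.kron(self.a.to_dense(), self.b.to_dense())


@singledispatch
def inverse(op):
    raise NotImplementedError(f"no inverse rule for {type(op).__name__}")


@inverse.register
def _(op: DenseOp):
    # Generic fallback: an O(n^3) dense inverse.
    return DenseOp(np.linalg.inv(op.matrix))


@inverse.register
def _(op: DiagonalOp):
    # O(n): invert each diagonal entry.
    return DiagonalOp(1.0 / op.diag)


@inverse.register
def _(op: KroneckerOp):
    # (a ⊗ b)^{-1} = a^{-1} ⊗ b^{-1}: recurse into the smaller factors.
    return KroneckerOp(inverse(op.a), inverse(op.b))


K = KroneckerOp(DiagonalOp([1.0, 2.0, 4.0]), DenseOp([[2.0, 1.0], [1.0, 3.0]]))
K_inv = inverse(K)
print(type(K_inv.a).__name__, type(K_inv.b).__name__)  # DiagonalOp DenseOp
```

The Kronecker rule never inverts the full 6 × 6 matrix; it inverts a three-entry diagonal and a 2 × 2 block, which is the kind of saving that structure-aware rules aim for.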

Citing us

If you use CoLA, please cite the following paper:

Andres Potapczynski, Marc Finzi, Geoff Pleiss, and Andrew Gordon Wilson. "Exploiting Compositional Structure for Automatic and Efficient Numerical Linear Algebra." Pre-print (2023). Link to be added soon.

@article{potapczynski2023cola,
  title={{Exploiting Compositional Structure for Automatic and Efficient Numerical Linear Algebra}},
  author={Andres Potapczynski and Marc Finzi and Geoff Pleiss and Andrew Gordon Wilson},
  journal={Pre-print},
  year={2023}
}

Quick start guide

  1. LinearOperators. The core object in CoLA is the LinearOperator. You can add and subtract them (+, -), multiply them by constants (*, /), matrix-multiply them (@), and combine them in other ways: kron, kronsum, block_diag, etc.
import jax.numpy as jnp
import cola

A = cola.ops.Diagonal(jnp.arange(5) + .1)  # 5x5 diagonal
B = cola.ops.Dense(jnp.array([[2., 1.], [-2., 1.1], [.01, .2]]))  # 3x2 dense
C = B.T @ B  # 2x2
D = C + 0.01 * cola.ops.I_like(C)  # add 0.01 times the identity
E = cola.ops.Kronecker(A, cola.ops.Dense(jnp.ones((2, 2))))  # 10x10 Kronecker product
F = cola.ops.BlockDiag(E, D)  # 12x12 block-diagonal operator

v = jnp.ones(F.shape[-1])
print(F @ v)
# Output:
# [0.2       0.2       2.2       2.2       4.2       4.2       6.2
#  6.2       8.2       8.2       7.8121004 2.062    ]
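To see what `F @ v` computes without ever forming the 12 × 12 matrix, here is a minimal NumPy sketch (not CoLA's implementation) that applies each structure directly: the block-diagonal operator splits `v` into pieces, the Kronecker factor is applied with a reshape, and the small dense block is multiplied as usual. The helpers `kron_matvec` and `block_diag_matvec` are hypothetical names introduced for illustration.

```python
import numpy as np


def kron_matvec(a, b, v):
    """Compute (a ⊗ b) @ v without forming the Kronecker product.

    Uses the row-major identity (a ⊗ b) @ vec(V) = vec(a @ V @ b.T),
    where V = v.reshape(a.shape[1], b.shape[1]).
    """
    V = v.reshape(a.shape[1], b.shape[1])
    return (a @ V @ b.T).reshape(-1)


def block_diag_matvec(blocks, v):
    """Apply a block-diagonal operator given as a list of (matvec, size) pairs."""
    out, start = [], 0
    for matvec, size in blocks:
        out.append(matvec(v[start:start + size]))
        start += size
    return np.concatenate(out)


A = np.diag(np.arange(5) + 0.1)
B = np.array([[2., 1.], [-2., 1.1], [.01, .2]])
D = B.T @ B + 0.01 * np.eye(2)
J = np.ones((2, 2))

v = np.ones(12)
# The first 10 entries go through A ⊗ J, the last 2 through D.
Fv = block_diag_matvec([(lambda x: kron_matvec(A, J, x), 10), (lambda x: D @ x, 2)], v)
print(Fv)
```

The result should agree with the printed CoLA output above, up to float32 versus float64 rounding.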
  2. Performing Linear Algebra. With these objects we can perform linear algebra operations even when the operators are very large.
print(cola.linalg.trace(F))
Q = F.T @ F + 1e-3 * cola.ops.I_like(F)  # symmetric positive-definite operator
b = cola.linalg.inverse(Q) @ v  # solve Q b = v
print(jnp.linalg.norm(Q @ b - v))  # residual of the solve
print(cola.linalg.eig(F)[0][:5])  # first five eigenvalues of F
print(cola.sqrt(A))  # square root of the diagonal operator A
# Output:
# 31.2701
# 0.0010193728
# [ 2.0000000e-01+0.j  0.0000000e+00+0.j  2.1999998e+00+0.j
#  -1.1920929e-07+0.j  4.1999998e+00+0.j]
# diag([0.31622776 1.0488088  1.4491377  1.7606816  2.0248456 ])
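The printed eigenvalues follow from two standard rules: the eigenvalues of a block-diagonal matrix are the union of its blocks' eigenvalues, and the eigenvalues of a Kronecker product are all pairwise products of the factors' eigenvalues. Since `jnp.ones((2, 2))` has eigenvalues 2 and 0, each diagonal entry a_i of `A` contributes 2·a_i and 0, which is where the (numerically) zero eigenvalues above come from; the trace 31.2701 is the sum of all eigenvalues. The NumPy check below rebuilds F densely and is only a verification sketch, not how CoLA computes these quantities.

```python
import numpy as np

a = np.arange(5) + 0.1
J = np.ones((2, 2))
B = np.array([[2., 1.], [-2., 1.1], [.01, .2]])
D = B.T @ B + 0.01 * np.eye(2)

# Kronecker rule: eigenvalues of diag(a) ⊗ J are a_i * mu_j, with mu = {0, 2}.
kron_eigs = np.outer(a, np.linalg.eigvalsh(J)).ravel()
# Block-diagonal rule: take the union with the eigenvalues of D.
rule_eigs = np.sort(np.concatenate([kron_eigs, np.linalg.eigvalsh(D)]))

# Dense reference for comparison (only feasible because F is tiny).
F = np.zeros((12, 12))
F[:10, :10] = np.kron(np.diag(a), J)
F[10:, 10:] = D
dense_eigs = np.linalg.eigvalsh(F)  # F is symmetric; eigenvalues come back in ascending order

print(np.allclose(rule_eigs, dense_eigs))  # True
print(rule_eigs.sum(), np.trace(F))  # both about 31.2701
```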

For many of these functions, if we know additional information about the matrices (for example, that they are symmetric), we can annotate them so that the algorithms run faster.

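Qs = cola.Symmetric(Q)  # annotate Q as symmetric so this information is available to the algorithms
# %timeit is an IPython/Jupyter magic; run the next two lines in a notebook
%timeit cola.linalg.inverse(Q) @ v
%timeit cola.linalg.inverse(Qs) @ v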
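One reason a symmetry annotation can pay off: for a symmetric positive-definite matrix such as `Q` (which is `F.T @ F` plus a small multiple of the identity), a linear solve can use conjugate gradients (CG), which needs only matrix-vector products and does not apply to general non-symmetric matrices. Below is a minimal textbook CG in NumPy for illustration; whether and when CoLA dispatches to CG for annotated operators is an assumption here, not something this README states, and `conjugate_gradient` is a hypothetical helper.

```python
import numpy as np


def conjugate_gradient(matvec, b, tol=1e-10, max_iters=None):
    """Solve A x = b for symmetric positive-definite A, given only x -> A @ x."""
    if max_iters is None:
        max_iters = 10 * len(b)
    x = np.zeros_like(b)
    r = b - matvec(x)  # residual
    p = r.copy()  # search direction
    rs_old = r @ r
    for _ in range(max_iters):
        Ap = matvec(p)
        alpha = rs_old / (p @ Ap)  # step length along p
        x += alpha * p
        r -= alpha * Ap
        rs_new = r @ r
        if np.sqrt(rs_new) < tol:
            break
        p = r + (rs_new / rs_old) * p  # next A-conjugate direction
        rs_old = rs_new
    return x


rng = np.random.default_rng(0)
M = rng.standard_normal((50, 50))
# Symmetric positive definite: a shifted Gram matrix, similar in form to Q above.
Q = M.T @ M / 50 + np.eye(50)
v = np.ones(50)
b = conjugate_gradient(lambda x: Q @ x, v)
print(np.linalg.norm(Q @ b - v))
```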
  3. JAX and PyTorch. We support both ML frameworks.
import torch
import cola

A = cola.ops.Dense(torch.tensor([[1., 2.], [3., 4.]]))
print(cola.linalg.trace(cola.kron(A, A)))

import jax.numpy as jnp
A = cola.ops.Dense(jnp.array([[1., 2.], [3., 4.]]))
print(cola.linalg.trace(cola.kron(A, A)))
# Output:
# tensor(25.)
# 25.0
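Both frameworks print 25 because of the identity tr(A ⊗ B) = tr(A) · tr(B): here tr(A) = 1 + 4 = 5, so tr(A ⊗ A) = 25. A quick NumPy check of that identity, independent of CoLA:

```python
import numpy as np

A = np.array([[1., 2.], [3., 4.]])
print(np.trace(np.kron(A, A)), np.trace(A) ** 2)  # 25.0 25.0

# The identity tr(X ⊗ Y) = tr(X) * tr(Y) holds for any square X and Y.
rng = np.random.default_rng(0)
X, Y = rng.standard_normal((3, 3)), rng.standard_normal((4, 4))
print(np.isclose(np.trace(np.kron(X, Y)), np.trace(X) * np.trace(Y)))  # True
```

This is why a trace of a Kronecker product never needs the full product: two small traces suffice.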

Both frameworks also support autograd (and JIT compilation):

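from jax import grad, jit, vmap

def myloss(x):
    # 2x2 operator whose bottom-right entry is the input x
    A = cola.ops.Dense(jnp.array([[1., 2.], [3., x]]))
    # scalar loss 1^T A^{-1} 1
    return jnp.ones(2) @ cola.linalg.inverse(A) @ jnp.ones(2)


# grad differentiates myloss w.r.t. x, vmap maps it over a batch of inputs, jit compiles the result
g = jit(vmap(grad(myloss)))(jnp.array([.5, 10.]))
print(g)
# Output:
# [-0.06611571 -0.12499995]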
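The printed gradients can be checked by hand. With A = [[1, 2], [3, x]], det A = x - 6 and the loss 1ᵀA⁻¹1 simplifies to (x - 4)/(x - 6), whose derivative is -2/(x - 6)². That gives -2/30.25 ≈ -0.0661 at x = 0.5 and -2/16 = -0.125 at x = 10, matching the output up to float32 rounding. A NumPy finite-difference sketch of the same check (`loss` and `closed_form_grad` are hypothetical helpers mirroring `myloss`):

```python
import numpy as np


def loss(x):
    """1^T A^{-1} 1 for A = [[1, 2], [3, x]], mirroring myloss above."""
    A = np.array([[1., 2.], [3., x]])
    ones = np.ones(2)
    return ones @ np.linalg.solve(A, ones)


def closed_form_grad(x):
    # loss(x) = (x - 4) / (x - 6), so d loss / dx = -2 / (x - 6)^2
    return -2.0 / (x - 6.0) ** 2


h = 1e-6
for x in (0.5, 10.0):
    fd = (loss(x + h) - loss(x - h)) / (2 * h)  # central finite difference
    print(x, fd, closed_form_grad(x))
```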

See https://cola.readthedocs.io/en/latest/ for our full documentation and many examples.

Use cases and examples

See our examples and tutorials on how to use CoLA for different problems.

Features being added

Linear Algebra Operations

  • inverse: $A^{-1}$
  • eig: $U \Lambda U^{-1}$
  • diag
  • trace
  • exp
  • logdet
  • $f(A)$
  • SVD
  • pseudoinverse
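For a diagonalizable matrix, the `eig` and $f(A)$ items are linked by $f(A) = U f(\Lambda) U^{-1}$, where $f$ is applied to each eigenvalue. Below is a minimal NumPy sketch for the symmetric case, where $U$ is orthogonal so $U^{-1} = U^\top$; the helper `matrix_function` is hypothetical and not part of CoLA.

```python
import numpy as np


def matrix_function(A, f):
    """Compute f(A) = U f(Lambda) U^T for a symmetric matrix A."""
    eigvals, U = np.linalg.eigh(A)
    return U @ np.diag(f(eigvals)) @ U.T


A = np.array([[2., 1.], [1., 2.]])  # eigenvalues 1 and 3
sqrt_A = matrix_function(A, np.sqrt)
exp_A = matrix_function(A, np.exp)

print(np.allclose(sqrt_A @ sqrt_A, A))  # True
# logdet(A) = tr(log A), another instance of the same idea
print(np.isclose(np.linalg.slogdet(A)[1], np.trace(matrix_function(A, np.log))))  # True
```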

Linear ops

  • Diag
  • BlockDiag
  • Kronecker
  • KronSum
  • Sparse
  • Jacobian
  • Hessian
  • Fisher
  • Concatenated
  • Triangular
  • FFT
  • Tridiagonal
  • CholeskyDecomposition
  • LUDecomposition
  • EigenDecomposition
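For reference, the Kronecker sum in the list above is commonly defined as $A \oplus B = A \otimes I_m + I_n \otimes B$ for $A \in \mathbb{R}^{n \times n}$ and $B \in \mathbb{R}^{m \times m}$, and its eigenvalues are all pairwise sums $\lambda_i(A) + \mu_j(B)$. A small NumPy sketch of that definition; the helper `kron_sum` is hypothetical and builds the dense matrix purely for checking, which a structured operator would avoid.

```python
import numpy as np


def kron_sum(A, B):
    """Dense Kronecker sum A ⊕ B = A ⊗ I_m + I_n ⊗ B (for checking only)."""
    n, m = A.shape[0], B.shape[0]
    return np.kron(A, np.eye(m)) + np.kron(np.eye(n), B)


A = np.diag([1., 2., 3.])
B = np.array([[2., 1.], [1., 2.]])  # eigenvalues 1 and 3

S = kron_sum(A, B)
pairwise = np.sort(np.add.outer(np.linalg.eigvalsh(A), np.linalg.eigvalsh(B)).ravel())
print(np.allclose(np.linalg.eigvalsh(S), pairwise))  # True
```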

Annotations

  • SelfAdjoint
  • PSD
  • Unitary
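As an illustration of why such annotations matter: if an operator is known to be unitary, its inverse is its conjugate transpose, $U^{-1} = U^{*}$, so a solve reduces to a single matrix-vector product. A minimal NumPy check of that property, using a random unitary matrix taken from a QR factorization (this is not CoLA code):

```python
import numpy as np

rng = np.random.default_rng(0)
Z = rng.standard_normal((4, 4)) + 1j * rng.standard_normal((4, 4))
U, _ = np.linalg.qr(Z)  # the Q factor of a square QR factorization is unitary

b = rng.standard_normal(4) + 1j * rng.standard_normal(4)
x_solve = np.linalg.solve(U, b)  # general O(n^3) solve
x_adjoint = U.conj().T @ b  # O(n^2) matvec using U^{-1} = U^*

print(np.allclose(x_solve, x_adjoint))  # True
```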

Contributing

See the contributing guidelines CONTRIBUTING.md for information on submitting issues and pull requests.

Acknowledgements

This work is supported by XXX.

Licence

CoLA is Apache 2.0 licensed.

Support and contact

Please raise an issue if you encounter a bug or poor performance when using CoLA.
