In [1]:
import torch
import matplotlib.pyplot as plt

In [2]:
# Diagonal of the test matrix: eigenvalues 0, 1, 2, 3, 4 with repeats.
eigs = torch.tensor((0, 1, 1, 2, 3, 4, 4, 4, 4))

# Jordan normal form: a 1 on the super-diagonal ties two neighbouring rows into
# the same block, giving one 2x2 block for eigenvalue 1, one 4x4 block for
# eigenvalue 4, and 1x1 blocks for 0, 2 and 3.
J = torch.diag(eigs) + torch.diag(torch.tensor([0, 1, 0, 0, 0, 1, 1, 1]), diagonal=1)
J = J.float()

# Draw the change-of-basis matrix from a dedicated, seeded generator so that a
# fresh kernel reproduces the same X (and hence the same A) without touching
# the global RNG state.
rng = torch.Generator().manual_seed(0)
X = torch.rand(J.shape, generator=rng)

# A is similar to J, so by construction J is its Jordan normal form.
A = X @ J @ X.inverse()

---

## Special formula for Jordan block form

In [3]:
from functools import reduce
from torch import func
from math import factorial

def jordan_form_func(function):
    """
    Lift a scalar function to matrices given in Jordan normal form.

    For every Jordan block with eigenvalue ``lam`` the result carries
    ``function(lam)`` on the diagonal and ``function^(k)(lam) / k!`` on the
    k-th super-diagonal of that block; derivatives come from ``torch.func.grad``.

    Parameters:
        function (callable): Differentiable function that maps a scalar tensor
            to a scalar tensor and also works element-wise on a 1-D tensor.

    Returns:
        callable: ``wrapper(input)`` taking a square Jordan-form matrix (or a
            0-dim tensor, treated as a 1x1 matrix) and returning a matrix of
            the same size.
    """

    def split_jordan_blocks(matrix):
        """
        Cut a Jordan normal form matrix into its diagonal Jordan blocks.

        Parameters:
            matrix (torch.Tensor): Square matrix in Jordan normal form.

        Returns:
            list: The diagonal blocks, in order, as tensors sliced from ``matrix``.
        """
        size = matrix.size(0)

        # A block stops after row i whenever the entry right of the diagonal is not 1.
        bounds = [0] + [row + 1 for row in range(size - 1) if matrix[row, row + 1] != 1]

        # Close the trailing block (skipped only for an empty matrix).
        if bounds[-1] < size:
            bounds.append(size)

        return [matrix[lo:hi, lo:hi] for lo, hi in zip(bounds, bounds[1:])]

    def grad(f, n=0):
        """Return the n-th derivative of ``f``, vectorised over a 1-D tensor of points."""
        derivative = f
        for _ in range(n):
            derivative = func.grad(derivative)
        return func.vmap(derivative)

    def wrapper(input):
        """Evaluate ``function`` on the Jordan-form matrix ``input``."""
        if input.dim() == 0:
            # Promote a bare scalar to a 1x1 matrix.
            input = input.reshape((1, 1))
        assert input.size(0) == input.size(1), "Input must be a square matrix"

        # Zeroth-order part: function applied to every eigenvalue on the diagonal.
        output = function(input.diagonal()).diag()

        offset = 0
        for block in split_jordan_blocks(input):
            size = block.size(0)
            lambdas = block.diagonal()
            window = slice(offset, offset + size)

            # The k-th super-diagonal of a block holds function^(k)(lambda) / k!.
            for order in range(1, size):
                band = grad(function, order)(lambdas[:-order]).diag(diagonal=order)
                output[window, window] += band / factorial(order)

            offset += size

        return output

    return wrapper

In [4]:
@jordan_form_func
def f(x):
    """Scalar test function sin(x) * cos(x), lifted to Jordan-form matrices by the decorator."""
    sine, cosine = torch.sin(x), torch.cos(x)
    return sine * cosine


f(J)

tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000],
        [ 0.0000,  0.4546, -0.4161,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000],
        [ 0.0000,  0.0000,  0.4546,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000],
        [ 0.0000,  0.0000,  0.0000, -0.3784,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000, -0.1397,  0.0000,  0.0000,  0.0000,
          0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.4947, -0.1455, -0.9894,
          0.0970],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.4947, -0.1455,
         -0.9894],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.4947,
         -0.1455],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.4947]])

## Taylor expansion for general matrix

In [60]:
def taylor_extension(n_terms = 30, loc = 0):
    """
    Build a decorator that extends a scalar function to square matrices
    through a truncated Taylor series.

    The decorated function ``f`` is replaced by

        F(M) = sum_{k=0}^{n_terms-1}  f^(k)(loc) / k!  *  (M - loc * I)^k

    where the derivatives ``f^(k)(loc)`` are obtained by repeated automatic
    differentiation of ``f`` at a float64 1x1 tensor holding ``loc``.

    Parameters:
        n_terms (int): Number of series terms to sum (orders 0 .. n_terms - 1).
        loc (float): Expansion point of the series.

    Returns:
        callable: A decorator; applied to ``f`` it yields ``wrapper(input)``,
            which takes a square matrix (or a 0-dim tensor, treated as a 1x1
            matrix) and returns a tensor with the shape and dtype of ``input``.
    """
    def decorator(function):
        def wrapper(input):
            """Evaluate the truncated Taylor series of ``function`` at the matrix ``input``."""
            input = input.reshape((1, 1)) if not input.dim() else input
            assert input.size(0) == input.size(1), "Input must be a square matrix"

            # Running value of k!, kept as a float64 tensor.
            term_factorial = torch.tensor([1.], dtype=torch.float64)
            x_loc = loc * torch.ones((1, 1), requires_grad=True, dtype=torch.float64)

            # The series is in powers of (input - loc * I); for loc == 0 this
            # is just input, so the default behaviour is unchanged.
            identity = torch.eye(input.size(0), dtype=input.dtype, device=input.device)
            shifted = input - loc * identity

            output = torch.zeros_like(input)
            derivative = None
            for term in range(n_terms):
                if term == 0:
                    derivative = function(x_loc)
                else:
                    # Once a derivative no longer depends on x_loc (polynomials,
                    # constants) every higher derivative is zero; asking
                    # autograd for it would raise, so stop summing here.
                    if not derivative.requires_grad:
                        break
                    # create_graph=True keeps the result differentiable for the next order.
                    derivative = torch.autograd.grad(derivative.sum(), x_loc, create_graph=True)[0]
                    term_factorial = term_factorial * term

                output += derivative / term_factorial * shifted.matrix_power(term)

            return output

        return wrapper
    return decorator

In [61]:
# Upper-triangular 3x3 test matrix with 10, -10, 10 on the diagonal.
A = torch.tensor(
    [[10, 1, 0], [0, -10, 1], [0, 0, 10]],
    dtype=torch.float64,
)


# Norm of the gap between the built-in matrix exponential and the 50-term series.
(
    torch.linalg.matrix_exp(A)
    - taylor_extension(50)(torch.exp)(A)
).norm()

tensor(1.5525e-11, dtype=torch.float64, grad_fn=<LinalgVectorNormBackward0>)

In [62]:
# Scalar sanity check: the 0-dim input is reshaped to a 1x1 matrix inside the
# wrapper, so this compares the 100-term series of sin at 1 with torch.sin(1).
taylor_extension(100)(torch.sin)(torch.tensor(1.)) - torch.sin(torch.tensor(1.))

tensor([[0.]], grad_fn=<SubBackward0>)

In [70]:
@taylor_extension(n_terms=30)
def g(x):
    """Scalar function exp(-x); the decorator turns it into a 30-term matrix series."""
    return (-x).exp()

g(torch.tensor([[1.]]))

tensor([[0.3679]], grad_fn=<AddBackward0>)

In [71]:
# NOTE(review): A has +/-10 on its diagonal, so the exact diagonal of exp(-A)
# is exp(-10) ~ 4.54e-5 and exp(10) ~ 2.2026e4. The displayed result matches
# the exp(10) entry, but the exp(-10) entries come out negative, so the
# 30-term series used by g has not converged for them.
g(A)

tensor([[-2.7996e-03, -1.1013e+03,  5.5066e+01],
        [ 0.0000e+00,  2.2026e+04, -1.1013e+03],
        [ 0.0000e+00,  0.0000e+00, -2.7996e-03]], dtype=torch.float64,
       grad_fn=<AddBackward0>)

In [72]:
# Seeded local generator: the random 25x25 test matrix, and therefore the
# error reported below, is the same on every fresh run.
B = torch.rand((25, 25), generator=torch.Generator().manual_seed(0))

# Norm of the gap between the built-in matrix exponential and a 5-term series.
(torch.linalg.matrix_exp(B) - taylor_extension(5)(torch.exp)(B)).norm()

tensor(295605.1250, grad_fn=<LinalgVectorNormBackward0>)