In [1]:
import torch
import matplotlib.pyplot as plt

In [2]:
# Eigenvalues of the 9x9 test matrix, listed in Jordan-block order.
eigs = torch.tensor((0, 1, 1, 2, 3, 4, 4, 4, 4))

# Jordan normal form: eigenvalues on the diagonal, and a 1 on the superdiagonal
# wherever two neighbouring entries belong to the same block.
# Resulting blocks: [0], [1, 1], [2], [3], [4, 4, 4, 4].
J = torch.diag(eigs) + torch.diag(torch.tensor([0, 1, 0, 0, 0, 1, 1, 1]), diagonal=1)
J = J.float()

# Random change of basis. It is drawn from a private, seeded generator so that
# X and A are identical on every "Restart & Run All" without touching the
# global RNG state used by other cells.
X = torch.rand(J.shape, generator=torch.Generator().manual_seed(0))

# General (non-Jordan) matrix that is similar to J.
A = X @ J @ X.inverse()

---

## Special formula for Jordan block form

In [3]:
from functools import reduce
from torch import func
from math import factorial

def jordan_form_func(function):
    """
    Lift a scalar function to matrices given in Jordan normal form.

    For a Jordan block of size m with eigenvalue lam, the result is upper
    triangular with f(lam) on the diagonal and f^(k)(lam) / k! on the k-th
    superdiagonal (k = 1 .. m - 1). The derivatives are obtained by nesting
    ``torch.func.grad`` and vectorising with ``torch.func.vmap``.

    Parameters:
        function (callable): Differentiable torch function that is applied to
            a 1-D tensor of eigenvalues and, under vmap/grad, to scalars.

    Returns:
        callable: ``wrapper(input)`` taking a square Jordan-form matrix (or a
        0-dim tensor, treated as a 1x1 matrix) and returning a tensor of the
        same shape.

    Raises:
        AssertionError: If the input matrix is not square.
    """
    def split_jordan_blocks(matrix):
        """
        Split a Jordan normal form matrix into its individual Jordan blocks.

        Parameters:
            matrix (torch.Tensor): The input square Jordan normal form matrix.

        Returns:
            list: A list of tensors, each representing an individual Jordan block.
        """
        n = matrix.size(0)
        blocks = []
        start_idx = 0

        for i in range(n - 1):
            if matrix[i, i + 1] != 1:  # End of a block
                # Extract the block
                blocks.append(matrix[start_idx:i + 1, start_idx:i + 1])
                start_idx = i + 1

        # Add the last block
        if start_idx < n:
            blocks.append(matrix[start_idx:n, start_idx:n])

        return blocks

    def grad(f, n=0):
        """Return the n-th derivative of ``f``, vectorised over a 1-D tensor of points."""
        return func.vmap(reduce(lambda g, _: func.grad(g), range(n), f))

    def wrapper(input):
        """Evaluate ``function`` on the Jordan-form matrix ``input``."""
        input = input.reshape((1, 1)) if not input.dim() else input
        assert input.size(0) == input.size(1), "Input must be a square matrix"

        # torch.func.grad cannot differentiate integer/bool tensors, so promote
        # them; floating-point and complex inputs are left untouched.
        if not (input.is_floating_point() or input.is_complex()):
            input = input.to(torch.get_default_dtype())

        eigs = input.diagonal()

        # Zeroth-order term: f(lambda) on the main diagonal.
        output = function(eigs).diag()

        # order -> vectorised derivative; built lazily once and shared by all blocks.
        derivatives = {}

        block_start_idx = 0
        for block in split_jordan_blocks(input):
            block_size = len(block)
            eig = block.diagonal()
            block_slice = slice(block_start_idx, block_start_idx+block_size)

            for i in range(1, block_size):
                if i not in derivatives:
                    derivatives[i] = grad(function, i)
                # The i-th superdiagonal of a block has block_size - i entries,
                # hence eig[:-i]; each entry is f^(i)(lambda) / i!.
                output[block_slice, block_slice] += derivatives[i](eig[:-i]).diag(diagonal=i) / factorial(i)

            block_start_idx += block_size

        return output

    return wrapper

In [4]:
@jordan_form_func
def f(x):
    """Return sin(x) * cos(x); the decorator applies it to Jordan-form matrices."""
    sine, cosine = torch.sin(x), torch.cos(x)
    return sine * cosine


f(J)

tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000],
        [ 0.0000,  0.4546, -0.4161,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000],
        [ 0.0000,  0.0000,  0.4546,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000],
        [ 0.0000,  0.0000,  0.0000, -0.3784,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000, -0.1397,  0.0000,  0.0000,  0.0000,
          0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.4947, -0.1455, -0.9894,
          0.0970],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.4947, -0.1455,
         -0.9894],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.4947,
         -0.1455],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.4947]])

## Taylor expansion for a general matrix

In [6]:
def taylor_extension(n_terms = 50, loc = 0.):
    """
    Build a decorator that extends a scalar function to square matrices via
    its truncated Taylor series around ``loc``:

        f(A) ~= sum_{k < n_terms} f^(k)(loc) / k! * (A - loc * I)^k

    The scalar derivatives f^(k)(loc) are obtained by repeated calls to
    ``torch.autograd.grad``. The series is only meaningful where it converges
    for the given function and matrix.

    Parameters:
        n_terms (int): Number of series terms, including the constant term.
        loc (float): Expansion point of the series.

    Returns:
        callable: A decorator. Applied to an element-wise torch function it
        returns ``wrapper(input)``, where ``input`` is a square matrix or a
        0-dim tensor (treated as a 1x1 matrix).

    Raises:
        AssertionError: If the input matrix is not square.
    """
    def compute_taylor(input, function):
        """Evaluate the truncated Taylor series of ``function`` at the matrix ``input``."""
        input = input.reshape((1, 1)) if not input.dim() else input
        assert input.size(0) == input.size(1), "Input must be a square matrix"

        identity = torch.eye(input.size(0), dtype=torch.float64)

        # Scalar expansion point tracked by autograd, kept as a (1, 1) tensor
        # so the derivatives broadcast against the matrix terms.
        x_loc = loc * torch.ones((1, 1), requires_grad=True, dtype=torch.float64)

        # A - loc * I; detached because only `function` has to be
        # differentiated with respect to x_loc.
        shifted = input - x_loc.detach() * identity

        derivative = function(x_loc)   # f^(0)(loc)
        scaled_power = identity        # (A - loc * I)^k / k!  for k = 0

        output = derivative * scaled_power

        for term in range(1, n_terms):
            # Once a derivative no longer depends on x_loc (polynomials,
            # constants), every higher derivative is zero and autograd would
            # raise, so the series is complete.
            if not derivative.requires_grad:
                break
            (derivative,) = torch.autograd.grad(
                derivative.sum(), x_loc, create_graph=True, allow_unused=True
            )
            if derivative is None:
                break

            # Fold 1/k into the running power instead of keeping k! and the
            # raw power separately: both of those overflow float64 on their
            # own long before their quotient does (0 * inf -> NaN).
            scaled_power = scaled_power @ shifted / term

            output = output + derivative * scaled_power

        return output

    def decorator(function):
        def wrapper(input):
            return compute_taylor(input=input, function=function)
        return wrapper

    return decorator

In [16]:
# 3x3 upper-bidiagonal test matrix: diagonal (10, -10, 10), ones directly above it.
A = (
    torch.diag(torch.tensor([10, -10, 10], dtype=torch.float64))
    + torch.diag(torch.ones(2, dtype=torch.float64), diagonal=1)
)

# 0-dim (scalar) test point.
x = torch.full((), 10., dtype=torch.float64)

In [30]:
# Frobenius norm of the gap between torch's matrix exponential and the
# 20-term Taylor series of exp expanded around 0.
torch.linalg.matrix_exp(A).sub(taylor_extension(20, 0)(torch.exp)(A)).norm()

tensor(111.4385, dtype=torch.float64, grad_fn=<LinalgVectorNormBackward0>)

In [None]:
# Same comparison for the scalar test point x, using a 10-term series.
taylor_extension(10)(torch.exp)(x).sub(torch.exp(x)).norm()

tensor(0., dtype=torch.float64, grad_fn=<LinalgVectorNormBackward0>)

In [37]:
@taylor_extension(n_terms=50, loc=1)
def h(x):
    """Natural logarithm; the decorator expands it into a 50-term Taylor series about 1."""
    return x.log()

In [None]:
# 0-dim test point for the logarithm series.
x = torch.full((), 2., dtype=torch.float64)

# Difference between the series approximation h and torch's logarithm at x.
h(x).sub(torch.log(x))

In [None]:
@taylor_extension(n_terms=50)
def g(x):
    """Scalar sine; the decorator lifts it to square matrices with a 50-term Taylor series."""
    return torch.sin(x)

# Sanity check on a 1x1 matrix. g approximates sin, so the reference must be
# torch.sin (comparing against exp(-x) can never give a near-zero residual).
g(torch.tensor([[1.]])) - torch.sin(torch.tensor([[1.]]))

In [135]:
# Evaluate the Taylor-series sine g on the matrix A.
g(A)

tensor([[-0.5440, -0.0544, -0.0392],
        [ 0.0000,  0.5440, -0.0544],
        [ 0.0000,  0.0000, -0.5440]], dtype=torch.float64,
       grad_fn=<AddBackward0>)

In [148]:
# Dense 100x100 random test matrix. It is drawn from a private, seeded
# generator so the residual below is the same on every run and the global RNG
# state is left alone.
B = torch.rand((100, 100), dtype=torch.float64, generator=torch.Generator().manual_seed(0))

# Absolute (Frobenius-norm) gap between torch's matrix exponential and the
# 180-term Taylor series of exp.
(torch.linalg.matrix_exp(B) - taylor_extension(180)(torch.exp)(B)).norm()

tensor(2356777.3223, dtype=torch.float64, grad_fn=<LinalgVectorNormBackward0>)