In [2]:
import torch
import matplotlib.pyplot as plt

In [3]:
# Diagonal of the Jordan matrix; runs of equal values can be chained into blocks.
eigs = torch.tensor((0, 1, 1, 2, 3, 4, 4, 4, 4))

# Superdiagonal ones link consecutive equal eigenvalues into one Jordan block.
# With these values the blocks are: size 1 (0), size 2 (1), size 1 (2), size 1 (3), size 4 (4).
superdiag = torch.tensor([0, 1, 0, 0, 0, 1, 1, 1])
J = (torch.diag(eigs) + torch.diag(superdiag, diagonal=1)).float()

# Random similarity transform. A dedicated seeded generator keeps the notebook
# reproducible under Restart & Run All without touching the global RNG state.
X = torch.rand(J.shape, generator=torch.Generator().manual_seed(0))

# A = X J X^{-1}, obtained by solving A X = X J instead of forming X^{-1} explicitly.
A = torch.linalg.solve(X, X @ J, left=False)

---

In [4]:
from functools import reduce
from torch import func
from math import factorial

def jordan_form_func(function):
    """
    Lift a scalar function to matrices given in Jordan normal form.

    For a Jordan block J_k(lam) = lam*I + N (N = ones on the first superdiagonal),
    the matrix function is the truncated Taylor series

        f(J_k(lam)) = sum_{i=0}^{k-1} f^(i)(lam) / i! * N^i,

    i.e. the i-th superdiagonal of each block holds f^(i)(lam) / i!.
    Derivatives are taken with nested ``torch.func.grad`` calls and vectorised
    with ``torch.func.vmap``.

    Parameters:
        function (callable): Scalar function. It is called on the 1-D tensor of
            diagonal entries and, under vmap/grad, on 0-d tensors. It is assumed to
            act elementwise and be differentiable by torch.func.

    Returns:
        callable: ``wrapper(input)`` mapping a Jordan-form matrix (a 0-d tensor is
        treated as a 1x1 matrix) to the matrix ``function(input)``.
    """
    def split_jordan_blocks(matrix):
        """
        Split a Jordan normal form matrix into its individual Jordan blocks.

        A block ends wherever the superdiagonal entry is not 1.

        Parameters:
            matrix (torch.Tensor): The input square Jordan normal form matrix.

        Returns:
            list: A list of tensors, each representing an individual Jordan block.
        """
        n = matrix.size(0)
        blocks = []
        start_idx = 0

        for i in range(n - 1):
            if matrix[i, i + 1] != 1:  # End of a block
                blocks.append(matrix[start_idx:i + 1, start_idx:i + 1])
                start_idx = i + 1

        # Add the last block
        if start_idx < n:
            blocks.append(matrix[start_idx:n, start_idx:n])

        return blocks

    def grad(f, n=0):
        """Return the n-th derivative of scalar function ``f``, vectorised over a 1-D input."""
        return func.vmap(reduce(lambda g, _: func.grad(g), range(n), f))

    def wrapper(input):
        """Apply ``function`` to the Jordan-form matrix ``input`` (see the decorator docstring)."""
        # A 0-d tensor is treated as a 1x1 matrix.
        input = input.reshape((1, 1)) if not input.dim() else input
        assert input.dim() == 2 and input.size(0) == input.size(1), "Input must be a square matrix"

        # A Jordan normal form is upper bidiagonal with 0/1 superdiagonal entries.
        # Anything else would otherwise be silently ignored and give a wrong result.
        assert not torch.triu(input, diagonal=2).any() and not torch.tril(input, diagonal=-1).any(), \
            "Input must be in Jordan normal form (non-zeros only on the diagonal and first superdiagonal)"
        superdiag = input.diagonal(offset=1)
        assert ((superdiag == 0) | (superdiag == 1)).all(), \
            "Jordan form superdiagonal entries must be 0 or 1"

        eigs = input.diagonal()

        # i = 0 term of every block: f(lam) on the main diagonal.
        output = function(eigs).diag()

        block_start_idx = 0
        for block in split_jordan_blocks(input):
            block_size = len(block)
            eig = block.diagonal()
            assert (eig == eig[0]).all(), "Each Jordan block must have a single repeated eigenvalue"
            block_slice = slice(block_start_idx, block_start_idx + block_size)

            for i in range(1, block_size):
                # eig[:-i] holds block_size - i copies of lam: exactly the length of the
                # i-th superdiagonal of this block.
                output[block_slice, block_slice] += grad(function, i)(eig[:-i]).diag(diagonal=i) / factorial(i)

            block_start_idx += block_size

        return output

    return wrapper

In [5]:
@jordan_form_func
def f(x):
    """Example scalar function sin(x) * cos(x), lifted to Jordan-form matrices by the decorator."""
    sine = torch.sin(x)
    cosine = torch.cos(x)
    return sine * cosine

In [6]:
torch.set_printoptions(precision=2)

# Independent cross-check: sin(x)cos(x) = sin(2x)/2, and for a real matrix M,
# sin(M) = Im(exp(iM)), so f(J) must equal Im(exp(2iJ)) / 2.
f_J = f(J)
assert torch.allclose(f_J, torch.linalg.matrix_exp(2j * J).imag / 2, atol=1e-4)

f_J

tensor([[ 0.00,  0.00,  0.00,  0.00,  0.00,  0.00,  0.00,  0.00,  0.00],
        [ 0.00,  0.45, -0.42,  0.00,  0.00,  0.00,  0.00,  0.00,  0.00],
        [ 0.00,  0.00,  0.45,  0.00,  0.00,  0.00,  0.00,  0.00,  0.00],
        [ 0.00,  0.00,  0.00, -0.38,  0.00,  0.00,  0.00,  0.00,  0.00],
        [ 0.00,  0.00,  0.00,  0.00, -0.14,  0.00,  0.00,  0.00,  0.00],
        [ 0.00,  0.00,  0.00,  0.00,  0.00,  0.49, -0.15, -0.99,  0.10],
        [ 0.00,  0.00,  0.00,  0.00,  0.00,  0.00,  0.49, -0.15, -0.99],
        [ 0.00,  0.00,  0.00,  0.00,  0.00,  0.00,  0.00,  0.49, -0.15],
        [ 0.00,  0.00,  0.00,  0.00,  0.00,  0.00,  0.00,  0.00,  0.49]])