In [20]:
import os
# Configure the XLA client allocator through environment variables.  This is
# done before `import jax` below so the values are already set when JAX loads.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"
# NOTE(review): ".XX" looks like a redacted placeholder, not a usable
# fraction -- confirm the intended value (or drop the line if unneeded).
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]=".XX"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"]="platform"

import jax
import jax.numpy as jnp
import jax.scipy as jsp
import sympy as sym

# Print arrays with two digits of precision; `suppress=True` is presumably
# the NumPy option that avoids scientific notation for small values.
jnp.set_printoptions(precision=2, suppress=True)

def ColumnContact(M):
    """Join the matrices in ``M`` side by side, in list order.

    Parameters
    ----------
    M : list
        Non-empty list whose elements expose ``row_join`` (the callers in
        this file pass SymPy column vectors).  The list is shallow-copied
        first, so the argument itself is left untouched.

    Returns
    -------
    The first element with every later element appended to its right via
    ``row_join``.  A one-element list yields that element unchanged.

    Raises
    ------
    IndexError
        If ``M`` is empty.
    """
    columns = M.copy()
    joined = columns[0]
    for column in columns[1:]:
        joined = joined.row_join(column)
    return joined

def is_in_space(space, v):
    """Tell whether ``v`` lies in the span of the vectors listed in ``space``.

    The check appends a copy of ``v`` to a copy of ``space``, joins the
    vectors into one matrix with ``ColumnContact`` and compares the rank of
    that matrix with ``len(space)``.  Because the comparison is against the
    *number* of vectors (not the rank of ``space``), the answer is only
    meaningful when the vectors in ``space`` are linearly independent.

    Parameters
    ----------
    space : list
        List of column vectors; it is not modified.
    v :
        Column vector to test; must provide ``copy()``.

    Returns
    -------
    bool
        ``False`` when adding ``v`` pushes the rank above ``len(space)``,
        ``True`` otherwise.
    """
    augmented = space.copy()
    basis_size = len(augmented)
    augmented.append(v.copy())
    return ColumnContact(augmented).rank() <= basis_size

def complementary(g, s):
    """Pick vectors from ``g`` that extend ``s`` to a larger independent family.

    Parameters
    ----------
    g : list
        Candidate vectors, scanned in order.
    s : list
        Linearly independent vectors spanning the subspace to complement.
        Neither ``g`` nor ``s`` is modified.

    Returns
    -------
    list
        At most ``len(g) - len(s)`` vectors taken from ``g`` (original
        objects, original order) such that ``s`` followed by the returned
        vectors is linearly independent.  The list is empty when
        ``len(g) == len(s)`` or when ``g`` is empty.
    """
    wanted = len(g) - len(s)
    # Family every candidate is tested against: `s` plus the vectors already
    # selected.  Testing against `s` alone would accept candidates that are
    # independent of `s` individually but dependent on it together (e.g.
    # s=[e1+e2], g=[e1, e2, e3] would yield [e1, e2] and lose e3).
    spanned = list(s)
    chosen = []
    for vect in g:
        if len(chosen) == wanted:
            break
        if not is_in_space(spanned, vect):
            chosen.append(vect)
            spanned.append(vect)
    return chosen

def null_space(A, eps=1e-12):
    """Return a basis of the (numerical) null space of ``A`` via its SVD.

    Parameters
    ----------
    A : array
        Two-dimensional real or complex matrix.
    eps : float, optional
        Absolute threshold below which a singular value counts as zero.  A
        relative floor ``max(s) * machine_eps * max(A.shape)`` is applied as
        well, so singular values that are zero up to rounding are still
        detected when ``eps`` is smaller than the working precision (the
        default ``1e-12`` is far below float32 resolution).

    Returns
    -------
    array of shape ``(A.shape[1], k)``
        The columns are the right singular vectors of ``A`` whose singular
        value is treated as zero, so ``A @ result`` is (numerically) zero.
        ``k`` is 0 when no singular value qualifies.
    """
    _, s, vh = jsp.linalg.svd(A)
    # With more columns than singular values, the trailing rows of `vh` have
    # no singular value at all and always belong to the null space.
    padding = max(0, jnp.shape(A)[1] - jnp.shape(s)[0])
    rel_tol = jnp.max(s, initial=0.0) * jnp.finfo(s.dtype).eps * max(jnp.shape(A))
    tol = jnp.maximum(eps, rel_tol)
    null_mask = jnp.concatenate(
        (s <= tol, jnp.ones((padding,), dtype=bool)), axis=0
    )
    basis_rows = jnp.compress(null_mask, vh, axis=0)
    # Rows of `vh` are the conjugate transposes of the right singular
    # vectors, so the conjugate is required for complex input (it is a no-op
    # for real input).
    return jnp.conj(jnp.transpose(basis_rows))
        
def jordan_decomposition(A, B):
    """Build a chain basis ``P`` for ``A`` and the matrix of ``A`` in it.

    Parameters
    ----------
    A : sympy.Matrix
        Square matrix.  Both returned values are computed from ``A`` only.
    B : jax array
        Numeric counterpart of ``A``.  It feeds the ``d_*`` quantities
        below, which are computed but do not influence the result yet.

    Returns
    -------
    P : sympy matrix
        The collected basis vectors joined as columns: for every eigenvalue,
        the chains found in its generalized eigenspace (eigenvector first),
        followed by the remaining eigenvectors.
    J :
        ``sym.Inverse(P) * A * P``.
    """
    n = A.shape[0]
    # Each entry is indexed below as [0] = eigenvalue, [1] = algebraic
    # multiplicity.
    evs = A.eigenvects()
    num = len(evs)

    # --- symbolic quantities, one entry per element of `evs` ---------------
    list_of_eigenvalues = [vals[0] for vals in evs]
    list_of_algebraic = [vals[1] for vals in evs]
    list_of_matrices = [A - r * sym.eye(n) for r in list_of_eigenvalues]
    list_of_genmatrices = [
        (A - list_of_eigenvalues[i] * sym.eye(n)) ** list_of_algebraic[i]
        for i in range(num)
    ]
    list_of_spaces = [M.nullspace() for M in list_of_matrices]
    list_of_genspaces = [M.nullspace() for M in list_of_genmatrices]

    # --- numeric counterparts computed from B (not consumed yet) -----------
    # NOTE(review): these lists follow the eigenvalues returned by
    # jnp.linalg.eig, which are presumably not index-aligned with `evs`;
    # confirm the pairing before relying on d_eigspace / d_genspace.
    d_list_of_eigenvalues = jnp.linalg.eig(B)[0]
    d_list_of_matrices = [B - r * jnp.eye(n) for r in d_list_of_eigenvalues]
    # Each shifted matrix is raised to the n-th power (previously every entry
    # was the un-powered matrix of eigenvalue 0).  The exponent n is always
    # large enough for the kernel to be the whole generalized eigenspace and
    # does not depend on the symbolic multiplicities.
    d_list_of_genmatrices = [
        jnp.linalg.matrix_power(M, n) for M in d_list_of_matrices
    ]
    d_list_of_spaces = [null_space(M) for M in d_list_of_matrices]
    d_list_of_genspaces = [null_space(M) for M in d_list_of_genmatrices]

    list_of_transform = []

    for k in range(num):
        phi = list_of_matrices[k]
        eigspace = list_of_spaces[k]
        genspace = list_of_genspaces[k]
        d_eigspace = d_list_of_spaces[k]
        d_genspace = d_list_of_genspaces[k]
        # Generalized eigenvectors that are not plain eigenvectors.
        complement = complementary(genspace, eigspace)
        list_of_vij = []

        while complement:
            list_of_bunch = []
            vect = complement[0]
            # Apply phi repeatedly until the vector lands in the eigenspace;
            # inserting at the front keeps the chain ordered from the
            # eigenvector end towards the starting vector.
            while not is_in_space(list_of_spaces[k], vect):
                list_of_bunch.insert(0, vect)
                vect = phi * vect
            # `vect` is now the eigenvector ending the chain: remove its
            # direction from the eigenvectors still to be added, and the
            # chain's directions from the pending generalized vectors.
            eigspace = complementary(eigspace, [vect])
            complement = complementary(complement, list_of_bunch)
            list_of_bunch.insert(0, vect)
            list_of_vij.extend(list_of_bunch)

        # Eigenvectors that do not terminate any chain.
        if eigspace:
            list_of_vij.extend(eigspace)

        list_of_transform.extend(list_of_vij)

    P = ColumnContact(list_of_transform)
    J = sym.Inverse(P)*A*P
    return P, J


# Example input: the same 4x4 integer matrix as an exact SymPy matrix (A)
# and as a JAX array (B).
A = sym.Matrix([[3, -2, 4, -2], [5, 3, -3, -2], [5, -2, 2, -2], [5, -2, -3, 3]])
B = jnp.array([[3, -2, 4, -2], [5, 3, -3, -2], [5, -2, 2, -2], [5, -2, -3, 3]])

P, J = jordan_decomposition(A, B)
# P is the change-of-basis matrix and J = P**-1 * A * P, so the labels name
# them as such (they are not the diagonal / nilpotent parts of A).
print("Transformation matrix P:")
display(P)
print("Jordan form J:")
display(J)
# Sanity check: P * J * P**-1 should reproduce A.
display(P*J*P.inv())

Diagonal matrix:


Matrix([
[0, 1, 1,  0],
[1, 1, 1, -1],
[1, 1, 1,  0],
[1, 1, 0,  1]])

Nilpotent matrix:


Matrix([
[-2, 0, 0, 0],
[ 0, 3, 0, 0],
[ 0, 0, 5, 0],
[ 0, 0, 0, 5]])

Matrix([
[3, -2,  4, -2],
[5,  3, -3, -2],
[5, -2,  2, -2],
[5, -2, -3,  3]])